Copyright 2024, MASSACHUSETTS INSTITUTE OF TECHNOLOGY Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014). SPDX-License-Identifier: MIT

Hugging Face Image Classification Example

The MAITE library provides interfaces for AI components such as datasets, models, metrics, and augmentations to make their use more consistent across test and evaluation (T&E) tools and workflows.

In this tutorial you will use MAITE, in conjunction with a set of common libraries, to:

  • Wrap an image classification dataset from Hugging Face (CIFAR-10),

  • Wrap an image classification model from Hugging Face (Vision Transformer),

  • Wrap a metric from TorchMetrics (multiclass accuracy), and

  • Compute performance on the clean dataset using MAITE’s evaluate workflow utility.

Once complete, you will have a basic understanding of MAITE’s interfaces for datasets, models, and metrics, as well as how to use MAITE’s native API for running evaluations.

This tutorial does not assume any prior knowledge, but some experience with Python, machine learning, and the PyTorch framework may be helpful.

Getting Started

This tutorial uses MAITE, PyTorch, Torchvision, TorchMetrics, Hugging Face datasets and transformers, and Matplotlib.

For running this notebook on your local machine, you can use the following commands to create a conda environment with the required dependencies:

conda create --name hf_image_classification python=3.10 pip
conda activate hf_image_classification
pip install maite datasets jupyter matplotlib torch torchmetrics torchvision transformers watermark
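Before moving on, it can help to confirm that the packages from the pip command are visible to the interpreter you are running. The following is a minimal sketch using only the standard library; the helper name missing_packages and the REQUIRED list are illustrative and are not part of MAITE.

```python
import importlib.util

# Import names of the packages installed by the pip command above
# (hypothetical list assembled for this check).
REQUIRED = [
    "maite", "datasets", "matplotlib", "torch",
    "torchmetrics", "torchvision", "transformers", "watermark",
]


def missing_packages(names):
    """Return the names that cannot be located in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_packages(REQUIRED)
print("All packages found." if not missing else f"Missing packages: {missing}")
```

If anything is reported missing, the most likely cause is that the notebook kernel is not using the conda environment created above.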

With the environment in place, we import the necessary libraries:

import datasets
import maite.protocols.image_classification as ic
import matplotlib.pyplot as plt
import numpy as np
import torch

from maite.protocols import ArrayLike, DatasetMetadata, MetricMetadata, ModelMetadata
from maite.workflows import evaluate
from torchmetrics import Accuracy, Metric
from torchvision.transforms.functional import to_tensor, resize
from transformers import AutoModelForImageClassification, ViTForImageClassification
from typing import Any, Optional, Sequence
from watermark import watermark
print("This notebook was executed with the following:\n")
print(watermark(python=True, packages="datasets,jupyter,matplotlib,numpy,torch,torchmetrics,torchvision,transformers,watermark"))
This notebook was executed with the following:

Python implementation: CPython
Python version       : 3.9.21
IPython version      : 8.18.1

datasets    : 3.3.1
jupyter     : 1.1.1
matplotlib  : 3.9.4
numpy       : 1.26.4
torch       : 2.6.0
torchmetrics: 1.6.1
torchvision : 0.21.0
transformers: 4.49.0
watermark   : 2.5.0

Wrapping a Hugging Face Dataset

We’ll be working with a common computer vision benchmark dataset called CIFAR-10, which consists of color images (size 32 x 32 pixels) covering 10 classes (airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck). The dataset is available through the Hugging Face datasets library, which provides access to officially curated datasets as well as datasets contributed to Hugging Face Hub from the machine learning community.

First we load a subset of the “native” Hugging Face dataset:

subset_size = 256
hf_dataset: datasets.Dataset = datasets.load_dataset("cifar10", split=f"test[:{subset_size}]") # type: ignore

Next, we wrap the dataset so it can be used with MAITE.

To facilitate executing T&E workflows with datasets from different sources (e.g., existing libraries like Torchvision or Hugging Face, or custom datasets), MAITE provides a Dataset protocol that specifies the expected interface (i.e., a minimal set of required attributes, methods, and method type signatures).

At a high level, a MAITE image classification dataset needs to have two methods (__len__ and __getitem__) and return the image, target (label/class), and metadata associated with a requested dataset index. The dataset also needs to have a metadata attribute containing some basic metadata (at least an id field).
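The idea of a protocol can be illustrated with Python's standard typing.Protocol: a class conforms by having the right members, not by inheriting from a base class. The sketch below is a simplified, hypothetical stand-in (DatasetLike, ToyDataset, and NotADataset are made-up names), not MAITE's actual protocol definition.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class DatasetLike(Protocol):
    """Hypothetical, simplified dataset protocol (not MAITE's definition)."""

    metadata: dict[str, Any]

    def __len__(self) -> int: ...

    def __getitem__(self, index: int) -> tuple[Any, Any, dict[str, Any]]: ...


class ToyDataset:
    """Conforms to DatasetLike structurally, without inheriting from it."""

    def __init__(self) -> None:
        self.metadata = {"id": "toy"}
        self._items = [("image-0", 0), ("image-1", 1)]

    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, index: int) -> tuple[Any, Any, dict[str, Any]]:
        image, target = self._items[index]
        return image, target, {"id": index}


class NotADataset:
    """Lacks __getitem__ and metadata, so it does not conform."""

    def __len__(self) -> int:
        return 0


# A static type checker accepts this assignment because the members match.
toy: DatasetLike = ToyDataset()

# The runtime check only verifies that the members exist, not their signatures.
print(isinstance(toy, DatasetLike))
print(isinstance(NotADataset(), DatasetLike))
```

The static check is the more informative of the two, since it also compares method signatures and return types.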

The following wrapper internally converts from the “native” format of the dataset to types compatible with MAITE:

class HuggingFaceDataset:
    def __init__(self, hf_dataset: datasets.Dataset, id: str, index2label: dict[int, str], resize_shape: Optional[list[int]] = None):
        self.hf_dataset = hf_dataset
        self.num_classes = hf_dataset.features["label"].num_classes
        self.resize_shape = resize_shape

        # Create required dataset metadata attribute
        self.metadata: DatasetMetadata = DatasetMetadata(id=id, index2label=index2label)

    def __len__(self) -> int:
        return len(self.hf_dataset)

    def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor, ic.DatumMetadataType]:
        if index < 0 or index >= len(self):
            raise IndexError(f"Index {index} is out of range for the dataset, which has length {len(self)}.")

        # Get the PIL image and integer label from the base HF dataset element (which is a dictionary)
        item = self.hf_dataset[index]
        img_pil = item["img"]
        label = item["label"]

        # Convert the PIL image to a PyTorch tensor for compatibility with PyTorch libraries
        img_pt = to_tensor(img_pil)

        # Apply resizing if requested
        if self.resize_shape is not None:
            img_pt = resize(img_pt, self.resize_shape)

        # Create one-hot encoded tensor with true class label for this image
        target = torch.zeros(self.num_classes)
        target[label] = 1

        return img_pt, target, ic.DatumMetadataType(id=index)
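The one-hot target built in __getitem__ can be illustrated in isolation. This is a minimal sketch in plain Python (lists instead of PyTorch tensors, so it runs without any dependencies); the helper names one_hot and argmax are illustrative.

```python
def one_hot(label: int, num_classes: int) -> list[float]:
    """Return a vector of zeros with a single 1.0 at position `label`."""
    if not 0 <= label < num_classes:
        raise ValueError(f"label {label} is outside [0, {num_classes})")
    target = [0.0] * num_classes
    target[label] = 1.0
    return target


def argmax(values: list[float]) -> int:
    """Return the index of the largest value (the first one on ties)."""
    return max(range(len(values)), key=values.__getitem__)


target = one_hot(3, num_classes=10)
print(target)
print(argmax(target))  # recovers the original label, 3
```

Taking the argmax of a one-hot vector recovers the integer class index, which is how the ground truth label is read back later in this tutorial.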

We now create an instance of the MAITE-compliant version of the Hugging Face dataset.

Note that the dataset variable has ic.Dataset as the type hint. If your environment has a static type checker enabled (e.g., the Pyright type checker via the Pylance language server in VS Code), then the type checker will verify that our wrapped dataset conforms to the protocol and indicate a problem if not (e.g., by underlining with a red squiggle).

# Create map from integer class index to string label
num_classes = hf_dataset.features["label"].num_classes
index2label = {i: hf_dataset.features["label"].int2str(i) for i in range(num_classes)}

# Wrap dataset
wrapped_hf_dataset: ic.Dataset = HuggingFaceDataset(
    hf_dataset,
    id="CIFAR-10",
    index2label=index2label,
    resize_shape=[224, 224]
)

print(f"{len(wrapped_hf_dataset) = }")
len(wrapped_hf_dataset) = 256

Here are some sample CIFAR-10 images along with their ground truth labels:

ncols = 6
fig, ax = plt.subplots(1, ncols, figsize=(6, 2))
for i in range(ncols):
    # Get datum i
    img, label_onehot, md = wrapped_hf_dataset[i]

    # Convert to NumPy array in height, width, color channel (HWC) order (for display with matplotlib)
    img_np = np.asarray(img).transpose(1, 2, 0)

    # Get ground truth class index and label
    index = torch.as_tensor(label_onehot).argmax().item()
    label = index2label[int(index)]

    # Plot image with label
    ax[i].axis("off")
    ax[i].imshow(img_np)
    ax[i].set_title(label)

fig.tight_layout()
[Figure: six sample CIFAR-10 images, each titled with its ground truth label]

Wrapping a Hugging Face Model

In this section, we’ll wrap a Hugging Face Vision Transformer (ViT) classification model that is available through Hugging Face Hub. The model has been trained on ImageNet-21k and fine-tuned on the CIFAR-10 dataset.

First we load the “native” Hugging Face model:

hf_model: ViTForImageClassification = AutoModelForImageClassification.from_pretrained(
    "aaraki/vit-base-patch16-224-in21k-finetuned-cifar10"
)

Next we wrap the model to conform to the MAITE ic.Model protocol, which requires a __call__ method that takes a batch of inputs and returns a batch of predictions. The model also needs to have a metadata attribute containing some basic metadata (at least an id field).

class HuggingFaceModel:
    def __init__(self, hf_model: ViTForImageClassification, id: str, index2label: dict[int, str], device: str = "cpu"):
        self.hf_model = hf_model
        self.device = device

        # Create required model metadata attribute
        self.metadata: ModelMetadata = ModelMetadata(id=id, index2label=index2label)

        # Move the model to requested device and set to eval mode
        self.hf_model.to(device) # type: ignore
        self.hf_model.eval()

    def __call__(self, batch: Sequence[ArrayLike]) -> Sequence[torch.Tensor]:
        # Combine inputs into PyTorch tensor of shape-(N,C,H,W) (batch size, color channels, height, width)
        batch_pt = torch.stack([torch.as_tensor(x) for x in batch])

        # Move tensor to the desired device
        batch_pt = batch_pt.to(self.device)

        # Apply model to batch (NOTE: preprocessing not needed for this particular HF model)
        output = self.hf_model(batch_pt)

        # Restructure to expected output format (sequence of probability/logit vectors)
        result = [x for x in output.logits.detach().cpu()]
        return result

wrapped_hf_model: ic.Model = HuggingFaceModel(
    hf_model,
    id="vit-base-patch16-224-in21k-finetuned-cifar10",
    index2label=index2label
)
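The wrapper above returns the raw logits from the Hugging Face model. If probabilities are wanted, a softmax can be applied to each logit vector; the predicted class is unchanged because softmax preserves the ordering of the values. Below is a minimal sketch in plain Python (the softmax helper is illustrative; in PyTorch the equivalent is torch.softmax).

```python
import math


def softmax(logits: list[float]) -> list[float]:
    """Convert a logit vector into probabilities that sum to 1."""
    shift = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(value - shift) for value in logits]
    total = sum(exps)
    return [value / total for value in exps]


logits = [2.0, 1.0, 0.1]
probs = softmax(logits)

# The largest logit and the largest probability share the same index.
pred_from_logits = max(range(len(logits)), key=logits.__getitem__)
pred_from_probs = max(range(len(probs)), key=probs.__getitem__)
print(probs, pred_from_logits, pred_from_probs)
```

Because the argmax is the same either way, an accuracy computation does not need this extra step.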

For an initial test, we’ll manually create an input batch and perform inference on it with the wrapped model:

# Create batch with single image
i = 0
x, y, md = wrapped_hf_dataset[i]
xb, yb, mdb = [x], [y], [md]

# Apply model and get first (only) prediction of size-1 batch of results
preds = wrapped_hf_model(xb)[0]
y_hat = torch.as_tensor(preds).argmax().item()

# Plot image with model prediction
fig, ax = plt.subplots(figsize=(1.5, 1.5))
img_np = np.asarray(x).transpose(1, 2, 0)
ax.axis("off")
ax.imshow(img_np)
ax.set_title(f"pred: {index2label[int(y_hat)]}")
fig.tight_layout()
[Figure: the first test image, titled with the model’s predicted class]

We see that the model predicts the correct class for the first example. A single image says little about overall performance, however, so next we perform a quantitative evaluation across a larger set of images.

Metrics

In this section we wrap a TorchMetrics metric to conform to the MAITE ic.Metric protocol.

First we create a “native” TorchMetrics accuracy metric:

tm_acc: Metric = Accuracy(task="multiclass", num_classes=10)

Next we wrap the metric as a MAITE ic.Metric that has the required update, compute, and reset methods, as well as the required metadata attribute:

class TorchMetricsClassificationMetric:
    def __init__(self, tm_metric: Metric, name: str, device: str = "cpu"):
        self.tm_metric = tm_metric
        self.name = name
        self.device = device

        # Create required metric metadata attribute
        self.metadata: MetricMetadata = MetricMetadata(id=name)

    def reset(self) -> None:
        self.tm_metric.reset()

    def update(self, preds: Sequence[ArrayLike], targets: Sequence[ArrayLike]) -> None:
        # Convert inputs to PyTorch tensors of shape-(N, num_classes)
        preds_pt: torch.Tensor = torch.stack([torch.as_tensor(x) for x in preds]).to(self.device)
        assert preds_pt.ndim == 2

        targets_pt: torch.Tensor = torch.stack([torch.as_tensor(x) for x in targets]).to(self.device)
        assert targets_pt.ndim == 2

        # Convert probabilities/logits to predicted class indices and update native TorchMetrics metric
        self.tm_metric.update(preds_pt.argmax(dim=1), targets_pt.argmax(dim=1))

    def compute(self) -> dict[str, Any]:
        result = {}
        result[self.name] = self.tm_metric.compute()
        return result

wrapped_tm_acc: ic.Metric = TorchMetricsClassificationMetric(tm_acc, "accuracy")
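To make the update, compute, and reset pattern concrete, here is a minimal sketch of an accuracy metric that accumulates state across batches. It is written in plain Python as a hypothetical stand-in (RunningAccuracy is a made-up name) and is not how TorchMetrics or MAITE implement accuracy.

```python
from typing import Any, Sequence


def argmax(values: Sequence[float]) -> int:
    """Return the index of the largest value (the first one on ties)."""
    return max(range(len(values)), key=values.__getitem__)


class RunningAccuracy:
    """Accumulates correct/total counts over any number of batches."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def reset(self) -> None:
        self.correct = 0
        self.total = 0

    def update(self, preds: Sequence[Sequence[float]], targets: Sequence[Sequence[float]]) -> None:
        # Compare the predicted class index with the one-hot target's index.
        for pred, target in zip(preds, targets):
            self.correct += int(argmax(pred) == argmax(target))
            self.total += 1

    def compute(self) -> dict[str, Any]:
        if self.total == 0:
            raise ValueError("compute() called before any update()")
        return {"accuracy": self.correct / self.total}


metric = RunningAccuracy()
metric.update([[0.1, 0.9], [0.8, 0.2]], [[0.0, 1.0], [0.0, 1.0]])  # one right, one wrong
first = metric.compute()
metric.update([[0.3, 0.7]], [[0.0, 1.0]])  # one more correct prediction
second = metric.compute()
print(first, second)
```

Keeping the counts as state is what allows a workflow to feed the metric one batch at a time and ask for the final value only once at the end.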

Workflows

Now we’ll run MAITE’s evaluate workflow, which manages the process of performing model inference on the dataset and computing the desired metric.

results, _, _ = evaluate(
    dataset=wrapped_hf_dataset,
    model=wrapped_hf_model,
    metric=wrapped_tm_acc
)

results
{'accuracy': tensor(0.9531)}

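We see that the model performs very well on this 256-image subset of the CIFAR-10 test split, achieving an accuracy of just over 95% (0.9531, which corresponds to 244 of the 256 images classified correctly).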
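Conceptually, an evaluation of this kind batches the dataset, runs the model on each batch, updates the metric, and computes the final value. The sketch below shows that loop with toy stand-ins in plain Python. It is an illustration under simplifying assumptions, not MAITE's implementation of evaluate; all names in it (simple_evaluate, ToyDataset, ThresholdModel, CountingAccuracy) are hypothetical.

```python
from typing import Any


class ToyDataset:
    """Eight scalar 'images' with one-hot targets; index 7 is deliberately mislabeled."""

    def __len__(self) -> int:
        return 8

    def __getitem__(self, index: int) -> tuple[float, list[float], dict[str, Any]]:
        if not 0 <= index < len(self):
            raise IndexError(index)
        label = 1 if 4 <= index < 7 else 0
        target = [0.0, 0.0]
        target[label] = 1.0
        return float(index), target, {"id": index}


class ThresholdModel:
    """Scores class 1 higher whenever the input is at least 4."""

    def __call__(self, batch: list[float]) -> list[list[float]]:
        return [[0.0, 1.0] if x >= 4 else [1.0, 0.0] for x in batch]


class CountingAccuracy:
    """Minimal metric with reset/update/compute methods."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def reset(self) -> None:
        self.correct = 0
        self.total = 0

    def update(self, preds: list[list[float]], targets: list[list[float]]) -> None:
        for pred, target in zip(preds, targets):
            self.correct += int(pred.index(max(pred)) == target.index(max(target)))
            self.total += 1

    def compute(self) -> dict[str, Any]:
        return {"accuracy": self.correct / self.total}


def simple_evaluate(dataset, model, metric, batch_size: int = 4) -> dict[str, Any]:
    """Run the model over the dataset in batches and return the metric result."""
    metric.reset()
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        data = [dataset[i] for i in range(start, stop)]
        # Collate: turn a list of (input, target, metadata) tuples into three lists.
        inputs, targets, _metadata = (list(column) for column in zip(*data))
        metric.update(model(inputs), targets)
    return metric.compute()


results = simple_evaluate(ToyDataset(), ThresholdModel(), CountingAccuracy(), batch_size=3)
print(results)  # {'accuracy': 0.875}
```

With a batch size of 3 the last batch holds only two items, which the loop handles without special casing; the toy model gets 7 of the 8 items right because of the one mislabeled example.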

Congratulations! You have now successfully used MAITE to wrap a dataset, model, and metric from various libraries, and run an evaluation to compute the performance of the pretrained model on a subset of the CIFAR-10 test split.