Copyright 2025, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
SPDX-License-Identifier: MIT

Wrap an Image Classification Model#

In this how-to, we will show you how to wrap a Torchvision ResNet50 to create a MAITE-compliant maite.protocols.image_classification.Model.

The general steps for wrapping a model are:

  • Understand the source (native) model

  • Create a wrapper that makes the source model conform to the MAITE model protocol (interface)

  • Verify that the wrapped model works correctly and has no static type checking errors

1 Load the Pretrained ResNet50 Model#

Load the required Python libraries:

import io
import json
import urllib.request
from typing import Callable, Sequence

import numpy as np
import PIL.Image
import torch as pt
import torchvision
from IPython.display import display

import maite.protocols.image_classification as ic
from maite.protocols import ArrayLike, ModelMetadata

%load_ext watermark
%watermark -iv -v
Python implementation: CPython
Python version       : 3.9.21
IPython version      : 8.18.1

maite      : 0.7.3
IPython    : 8.18.1
torchvision: 0.21.0
json       : 2.0.9
PIL        : 11.1.0
torch      : 2.6.0
numpy      : 1.26.4

Instantiate the ResNet50 model (pretrained on the ImageNet dataset), following the Torchvision documentation:

model_weights = torchvision.models.ResNet50_Weights.IMAGENET1K_V2

model = torchvision.models.resnet50(
    weights=model_weights
)  # weights will download automatically

model = model.eval()  # set the ResNet50 model to eval mode
Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 163MB/s]
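
As a quick sanity check, the final fully connected layer of a Torchvision ResNet (exposed as model.fc) should produce one logit per ImageNet class:

print(model.fc)  # expect: Linear(in_features=2048, out_features=1000, bias=True)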

2 Perform Model Inference on Sample Images#

Download the ImageNet labels and a couple of sample ImageNet images:

labels_url = "https://raw.githubusercontent.com/raghakot/keras-vis/refs/heads/master/resources/imagenet_class_index.json"
response = urllib.request.urlopen(labels_url).read()
labels = json.loads(response.decode("utf-8"))

img_url = "https://raw.githubusercontent.com/EliSchwartz/imagenet-sample-images/master/n01491361_tiger_shark.JPEG"
image_data = urllib.request.urlopen(img_url).read()
example_img_1 = PIL.Image.open(io.BytesIO(image_data))

img_url = "https://raw.githubusercontent.com/EliSchwartz/imagenet-sample-images/master/n01695060_Komodo_dragon.JPEG"
image_data = urllib.request.urlopen(img_url).read()
example_img_2 = PIL.Image.open(io.BytesIO(image_data))

example_imgs = [example_img_1, example_img_2]

The downloaded labels form a dictionary mapping label-number strings to lists containing a WordNet ID and a human-readable label:

dict(list(labels.items())[:4])
{'0': ['n01440764', 'tench'],
 '1': ['n01443537', 'goldfish'],
 '2': ['n01484850', 'great_white_shark'],
 '3': ['n01491361', 'tiger_shark']}

Next we check that the model works as expected on the sample images.

Note that the weights object for Torchvision models includes a transforms() method that performs the model-specific input transformations (e.g., resizing, interpolation, and normalization) required by the model. It is important to remember that there is no standard way to do this; the required preprocessing can vary for every model.
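
To see exactly which transformations these particular weights apply, you can print the transform object returned by transforms(); its repr lists the configured resize, crop, and normalization settings:

print(model_weights.transforms())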

def prediction_label(logits):
    # argmax over the flattened tensor works for both batched (1, num_classes)
    # and unbatched (num_classes,) logits
    label_pred = logits.argmax().item()
    return f"{label_pred} {labels[str(label_pred)]}"

preprocess = model_weights.transforms()

for example_img in example_imgs:
    img_chw = preprocess(example_img)  # model-specific preprocessing, returns (C, H, W)
    logits = model(img_chw.unsqueeze(0))  # unsqueeze adds the batch dimension: (1, C, H, W)

    print(
        f"""
    ResNet50 Outputs
    ================
    Result Type: {type(logits)}
    Result Shape: {logits.shape}
    Sample Prediction: {prediction_label(logits)}
    """
    )

    display(example_img)
ResNet50 Outputs
================
Result Type: <class 'torch.Tensor'>
Result Shape: torch.Size([1, 1000])
Sample Prediction: 3 ['n01491361', 'tiger_shark']
[displayed image: tiger shark sample]
ResNet50 Outputs
================
Result Type: <class 'torch.Tensor'>
Result Shape: torch.Size([1, 1000])
Sample Prediction: 48 ['n01695060', 'Komodo_dragon']
[displayed image: Komodo dragon sample]

3 Create the MAITE Model Wrapper#

A MAITE maite.protocols.image_classification.Model wrapper stores a reference to a model and model metadata. For this model we also store the preprocessing function.

A MAITE-compliant image classification Model need only implement the following method:

  • __call__(batch: Sequence[ArrayLike]), which makes model predictions for the inputs in a batch. It must transform its input from a Sequence[ArrayLike] into the format expected by the model, and transform the model's outputs back into a Sequence[ArrayLike] containing the predictions.

and an attribute:

  • metadata, which is a typed dictionary containing an id field for the model.

Note that MAITE requires the dimensions of each image in the input batch to be (C, H, W), which corresponds to the image’s color channels, height, and width, respectively.

Torchvision’s data preprocessing function, transforms(), mentioned in Section 2, already accepts torch.Tensors of shape (C, H, W), which are compatible with MAITE ArrayLike.
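
As an illustrative check (the dummy tensor below is only an example), you can pass a (C, H, W) tensor through the transform and confirm that it comes back resized and normalized to the model's expected input shape:

dummy_chw = pt.zeros(3, 300, 400)  # hypothetical (C, H, W) image tensor
print(preprocess(dummy_chw).shape)  # e.g., torch.Size([3, 224, 224]) for these weights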

class TorchvisionResNetModel:
    def __init__(
        self,
        model: torchvision.models.ResNet,
        preprocess: Callable[[pt.Tensor], pt.Tensor],
    ) -> None:
        self.metadata: ModelMetadata = {
            "id": "Torchvision ResNet ImageNet 1k",
            "index2label": {
                int(idx): label for idx, [_wordnetid, label] in labels.items()
            },
        }
        self.model = model
        self.preprocess = preprocess

    def __call__(self, batch: Sequence[ArrayLike]) -> Sequence[pt.Tensor]:
        # Preprocess the inputs to ensure they match the model's input format
        imgs_chw = []
        for img_chw in batch:
            imgs_chw.append(self.preprocess(pt.as_tensor(img_chw)))

        # Create a shape-(N,C,H,W) tensor from the list of (C,H,W) tensors
        # Note: Images have been preprocessed to all have the same shape
        img_nchw = pt.stack(imgs_chw)

        # Call the model
        logits = self.model(img_nchw)

        # Convert the shape-(N,num_classes) logits tensor into a list of shape-(num_classes,) tensors
        return [t for t in logits]

# Wrap the Torchvision ResNet model
wrapped_model: ic.Model = TorchvisionResNetModel(model, preprocess)
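
The ic.Model annotation above lets a static type checker such as pyright verify the assignment. Because MAITE protocols are runtime-checkable, you can also sanity-check structural conformance with isinstance:

# Runtime structural check; assumes your MAITE version marks the
# Model protocol as @runtime_checkable (recent versions do)
assert isinstance(wrapped_model, ic.Model)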

4 Examine the MAITE-wrapped Model Output#

Build a batch of MAITE ArrayLike test images and visualize the wrapped model’s output:

def pil_image_to_maite(img_pil):
    # Convert the PIL image to a NumPy array
    # Note: assumes an RGB image, so the array has a channel dimension
    img_hwc = np.array(img_pil)  # shape (H, W, C)

    # Use MAITE array index convention for representing images: (C, H, W)
    img_chw = img_hwc.transpose(2, 0, 1)

    return img_chw

batch = [pil_image_to_maite(i) for i in example_imgs]
predictions = wrapped_model(batch)

print(
    f"""
ResNet50 Outputs
================
Result Type: {type(predictions)}
Individual Prediction Type: {type(predictions[0])}
Individual Prediction Shape: {pt.as_tensor(predictions[0]).shape}
Sample Predictions:""")
for prediction, example_img in zip(predictions, example_imgs):
    print(f"    Predicted label = {prediction_label(prediction)}")
    display(example_img)
ResNet50 Outputs
================
Result Type: <class 'list'>
Individual Prediction Type: <class 'torch.Tensor'>
Individual Prediction Shape: torch.Size([1000])
Sample Predictions:
    Predicted label = 3 ['n01491361', 'tiger_shark']
[displayed image: tiger shark sample]
    Predicted label = 48 ['n01695060', 'Komodo_dragon']
[displayed image: Komodo dragon sample]

5 Conclusion#

Wrapping image classification models with MAITE ensures interoperability and simplifies integration into test & evaluation (T&E) workflows that “know” how to work with MAITE models, since the models’ inputs and outputs are standardized. T&E workflows designed around MAITE protocols, such as MAITE’s evaluate, will work seamlessly across models, including this wrapped ResNet50.

The key to model wrapping is to define the following (a minimal skeleton is sketched after this list):

  • a __call__ method that receives an input of type maite.protocols.image_classification.InputBatchType and returns a Sequence[ArrayLike].

  • a metadata typed dictionary attribute with at least an id field.
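
For reference, here is a minimal sketch of the overall pattern. The class name, metadata id, and method bodies are illustrative placeholders, not part of MAITE itself:

from typing import Callable, Sequence

import torch as pt

from maite.protocols import ArrayLike, ModelMetadata


class MySketchModel:
    """Minimal sketch of a MAITE-compliant image classification model."""

    def __init__(self, model: Callable[[pt.Tensor], pt.Tensor]) -> None:
        self.metadata: ModelMetadata = {"id": "my-model"}  # "id" is required
        self.model = model

    def __call__(self, batch: Sequence[ArrayLike]) -> Sequence[pt.Tensor]:
        # 1. Convert each (C, H, W) ArrayLike to the native model's input format
        imgs = [pt.as_tensor(img) for img in batch]
        # 2. Run the native model on a stacked (N, C, H, W) tensor
        logits = self.model(pt.stack(imgs))
        # 3. Return predictions as a sequence of per-image ArrayLikes
        return [t for t in logits]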