Copyright 2025, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
SPDX-License-Identifier: MIT
Wrap an Image Classification Model#
In this how-to, we will show you how to wrap a Torchvision ResNet50 to create a MAITE-compliant maite.protocols.image_classification.Model.
The general steps for wrapping a model are:
Understand the source (native) model
Create a wrapper that makes the source model conform to the MAITE model protocol (interface)
Verify that the wrapped model works correctly and has no static type checking errors
1 Load the Pretrained ResNet50 Model#
Load the required Python libraries:
import io
import json
import urllib.request
from typing import Callable, Sequence
import numpy as np
import PIL.Image
import torch as pt
import torchvision
from IPython.display import display
import maite.protocols.image_classification as ic
from maite.protocols import ArrayLike, ModelMetadata
%load_ext watermark
%watermark -iv -v
Python implementation: CPython
Python version : 3.9.21
IPython version : 8.18.1
maite : 0.7.3
IPython : 8.18.1
torchvision: 0.21.0
json : 2.0.9
PIL : 11.1.0
torch : 2.6.0
numpy : 1.26.4
Instantiate the ResNet50 model (pretrained on the ImageNet dataset), following the Torchvision documentation:
model_weights = torchvision.models.ResNet50_Weights.IMAGENET1K_V2
model = torchvision.models.resnet50(
weights=model_weights
) # weights will download automatically
model = model.eval() # set the ResNet50 model to eval mode
Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 163MB/s]
2 Perform Model Inference on Sample Images#
Download the ImageNet labels and a couple of sample ImageNet images:
labels_url = "https://raw.githubusercontent.com/raghakot/keras-vis/refs/heads/master/resources/imagenet_class_index.json"
response = urllib.request.urlopen(labels_url).read()
labels = json.loads(response.decode("utf-8"))
img_url = "https://raw.githubusercontent.com/EliSchwartz/imagenet-sample-images/master/n01491361_tiger_shark.JPEG"
image_data = urllib.request.urlopen(img_url).read()
example_img_1 = PIL.Image.open(io.BytesIO(image_data))
img_url = "https://raw.githubusercontent.com/EliSchwartz/imagenet-sample-images/master/n01695060_Komodo_dragon.JPEG"
image_data = urllib.request.urlopen(img_url).read()
example_img_2 = PIL.Image.open(io.BytesIO(image_data))
example_imgs = [example_img_1, example_img_2]
The downloaded labels form a dictionary mapping class-index strings to lists containing a WordNet ID and a human-readable label:
{'0': ['n01440764', 'tench'],
'1': ['n01443537', 'goldfish'],
'2': ['n01484850', 'great_white_shark'],
'3': ['n01491361', 'tiger_shark']}
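For example, an individual entry can be unpacked into its WordNet ID and label (an illustrative lookup; the key is the class index as a string):
# Look up the WordNet ID and human-readable label for class index 3
wordnet_id, class_name = labels["3"]
print(wordnet_id, class_name)  # n01491361 tiger_shark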
Next we check that the model works as expected on the sample images.
Note that the weights object for Torchvision models includes a transforms()
method that performs model-specific input transformations, such as
resizing and interpolation, as required by the model. Keep in mind
that there is no standard preprocessing; it can vary for
every model.
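To see exactly which preprocessing steps a given set of weights applies, you can print the transform object (the exact output varies by model and Torchvision version):
# Inspect the model-specific preprocessing bundled with the pretrained weights
print(model_weights.transforms())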
def prediction_label(logits):
    # Map the highest-scoring logit to its class index and ImageNet label entry
    logits = logits.unsqueeze(0)  # adding a batch dimension does not affect the global argmax
    label_pred = logits.argmax().item()
    return f"{label_pred} {labels[str(label_pred)]}"
preprocess = model_weights.transforms()
for example_img in example_imgs:
    model_input = preprocess(example_img)
    logits = model(model_input.unsqueeze(0))  # use unsqueeze to add batch dimension
print(
f"""
ResNet50 Outputs
================
Result Type: {type(logits)}
Result Shape: {logits.shape}
Sample Prediction: {prediction_label(logits)}
"""
)
display(example_img)
ResNet50 Outputs
================
Result Type: <class 'torch.Tensor'>
Result Shape: torch.Size([1, 1000])
Sample Prediction: 3 ['n01491361', 'tiger_shark']

ResNet50 Outputs
================
Result Type: <class 'torch.Tensor'>
Result Shape: torch.Size([1, 1000])
Sample Prediction: 48 ['n01695060', 'Komodo_dragon']

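If class probabilities are more convenient than raw logits, a softmax can be applied; this is an optional step (not part of the original workflow) that reuses the logits from the last loop iteration above:
# Convert the shape-(1, 1000) logits to probabilities and report the top-1 score
probs = logits.softmax(dim=-1)
top_prob, top_idx = probs.max(dim=-1)
print(f"Top-1 class {top_idx.item()} with probability {top_prob.item():.3f}")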
3 Create the MAITE Model Wrapper#
A MAITE maite.protocols.image_classification.Model wrapper stores a reference to a model and model metadata. For this model we also store the preprocessing function.
A MAITE-compliant image classification Model need only implement the following method:
__call__(batch: Sequence[ArrayLike])
which makes a model prediction for the inputs in an input batch. It must transform its inputs from a Sequence[ArrayLike] to the format expected by the model, and transform the model outputs to a Sequence[ArrayLike] containing the predictions.
It must also provide an attribute:
metadata
which is a typed dictionary containing an id field for the model.
Note that MAITE requires the dimensions of each image in the input batch to be (C, H, W), which corresponds to the image’s color channels, height, and width, respectively.
Torchvision’s data preprocessing function, transforms(), mentioned in Section 2, already accepts torch.Tensors of shape (C, H, W), which are compatible with MAITE’s ArrayLike.
class TorchvisionResNetModel():
def __init__(
self,
model: torchvision.models.ResNet,
preprocess: Callable[[pt.Tensor], pt.Tensor],
) -> None:
self.metadata: ModelMetadata = {
"id": "Torchvision ResNet ImageNet 1k",
"index2label": {
int(idx): label for idx, [_wordnetid, label] in labels.items()
},
}
self.model = model
self.preprocess = preprocess
def __call__(self, batch: Sequence[ArrayLike]) -> Sequence[pt.Tensor]:
# Preprocess the inputs to ensure they match the model's input format
imgs_chw = []
for img_chw in batch:
imgs_chw.append(self.preprocess(pt.as_tensor(img_chw)))
# Create a shape-(N,C,H,W) tensor from the list of (C,H,W) tensors
# Note: Images have been preprocessed to all have the same shape
img_nchw = pt.stack(imgs_chw)
# Call the model
logits = self.model(img_nchw)
# Convert the shape-(N,num_classes) logits tensor into a list of shape-(num_classes,) tensors
return [t for t in logits]
# Wrap the Torchvision ResNet model
wrapped_model: ic.Model = TorchvisionResNetModel(model, preprocess)
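Because MAITE protocols are Python Protocol classes, you can complement static type checking with a quick structural check (this assumes ic.Model is runtime-checkable in your MAITE version):
# Structural check of protocol compliance; static type checkers (e.g., mypy, pyright)
# provide the stronger guarantee against the ic.Model protocol
assert isinstance(wrapped_model, ic.Model)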
4 Examine the MAITE-wrapped Model Output#
Build a batch of MAITE ArrayLike
test images and visualize the
wrapped model’s output:
def pil_image_to_maite(img_pil):
    # Convert the PIL image to a NumPy array
img_hwc = np.array(img_pil) # shape (H, W, C)
# Use MAITE array index convention for representing images: (C, H, W)
img_chw = img_hwc.transpose(2, 0, 1)
return img_chw
batch = [pil_image_to_maite(i) for i in example_imgs]
predictions = wrapped_model(batch)
print(
f"""
ResNet50 Outputs
================
Result Type: {type(predictions)}
Individual Prediction Type: {type(predictions[0])}
Individual Prediction Shape: {pt.as_tensor(predictions[0]).shape}
Sample Predictions:""")
for prediction, example_img in zip(predictions, example_imgs):
print(f" Predicted label = {prediction_label(prediction)}")
display(example_img)
ResNet50 Outputs
================
Result Type: <class 'list'>
Individual Prediction Type: <class 'torch.Tensor'>
Individual Prediction Shape: torch.Size([1000])
Sample Predictions:
Predicted label = 3 ['n01491361', 'tiger_shark']

Predicted label = 48 ['n01695060', 'Komodo_dragon']

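The wrapper’s metadata can also be used to recover human-readable labels from predictions; a brief sketch using the index2label mapping populated in the wrapper above:
# Map each prediction's argmax index back to a label via the wrapper's metadata
index2label = wrapped_model.metadata["index2label"]
for prediction in predictions:
    idx = int(pt.as_tensor(prediction).argmax().item())
    print(idx, index2label[idx])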
5 Conclusion#
Wrapping image classification models with MAITE ensures interoperability and simplifies integration into test & evaluation (T&E) workflows that “know” how to work with MAITE models (since the model’s inputs and outputs are standardized). T&E workflows designed around MAITE protocols, such as MAITE’s evaluate, will work seamlessly across models, including ResNet50.
The key to model wrapping is to define the following:
a __call__ method that receives an input of type maite.protocols.image_classification.InputBatchType and returns a Sequence[ArrayLike]
a metadata typed dictionary attribute with at least an id field