..
   Copyright 2025, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
   Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
   SPDX-License-Identifier: MIT

Wrap an Image Classification Model
==================================

In this how-to, we will show you how to wrap a `Torchvision ResNet50
<https://pytorch.org/vision/stable/models/generated/torchvision.models.resnet50.html>`__
to create a MAITE-compliant
`maite.protocols.image_classification.Model <../generated/maite.protocols.image_classification.Model.html>`__.

The general steps for wrapping a model are:

- Understand the source (native) model
- Create a wrapper that makes the source model conform to the MAITE
  model protocol (interface)
- Verify that the wrapped model works correctly and has no static type
  checking errors

1 Load the Pretrained ResNet50 Model
------------------------------------

Load the required Python libraries:

.. code:: ipython3

    import io
    import json
    import urllib.request
    from typing import Callable, Sequence

    import numpy as np
    import PIL.Image
    import torch as pt
    import torchvision
    from IPython.display import display

    import maite.protocols.image_classification as ic
    from maite.protocols import ArrayLike, ModelMetadata

    %load_ext watermark
    %watermark -iv -v

.. parsed-literal::

    Python implementation: CPython
    Python version       : 3.9.21
    IPython version      : 8.18.1

    maite      : 0.7.3
    IPython    : 8.18.1
    torchvision: 0.21.0
    json       : 2.0.9
    PIL        : 11.1.0
    torch      : 2.6.0
    numpy      : 1.26.4

Instantiate the ResNet50 model (pretrained on the `ImageNet dataset
<https://www.image-net.org>`__), following the `Torchvision
documentation <https://pytorch.org/vision/stable/models.html>`__:

.. code:: ipython3

    model_weights = torchvision.models.ResNet50_Weights.IMAGENET1K_V2
    model = torchvision.models.resnet50(
        weights=model_weights
    )  # weights will download automatically
    model = model.eval()  # set the ResNet50 model to eval mode

.. parsed-literal::

    Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth

.. parsed-literal::

    0%|          | 0.00/97.8M [00:00<?, ?B/s]

2 Load Sample Images and ImageNet Labels
----------------------------------------

Download a couple of sample images to exercise the model. The outputs
below were produced with photos of a tiger shark and a Komodo dragon;
substitute any publicly hosted RGB photos you like:

.. code:: ipython3

    def download_image(url: str) -> PIL.Image.Image:
        with urllib.request.urlopen(url) as response:
            return PIL.Image.open(io.BytesIO(response.read())).convert("RGB")

    # Placeholder URLs -- replace with links to your own sample photos
    example_img_urls = [
        "https://example.com/tiger_shark.jpg",
        "https://example.com/komodo_dragon.jpg",
    ]
    example_imgs = [download_image(url) for url in example_img_urls]

We also download the ImageNet class index file, which maps each of the
1000 class indices to its `WordNet <https://wordnet.princeton.edu>`__ ID
and human readable labels:

.. code:: ipython3

    labels_url = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
    with urllib.request.urlopen(labels_url) as response:
        labels = json.load(response)

.. code:: ipython3

    dict(list(labels.items())[:4])

.. parsed-literal::

    {'0': ['n01440764', 'tench'],
     '1': ['n01443537', 'goldfish'],
     '2': ['n01484850', 'great_white_shark'],
     '3': ['n01491361', 'tiger_shark']}

Next we check that the model works as expected on the sample images.
Note that the weights for Torchvision models include a ``transforms()``
method that performs model-specific input transformations, such as
resizing and interpolating, as required by the model. It is important
to remember that there is no standard preprocessing: the required
transformations can vary from model to model.

.. code:: ipython3

    def prediction_label(logits):
        # add a leading batch dim; argmax() flattens, so this handles
        # both batched and unbatched logits
        logits = logits.unsqueeze(0)
        label_pred = logits.argmax().item()
        return f"{label_pred} {labels[str(label_pred)]}"

    preprocess = model_weights.transforms()

    for example_img in example_imgs:
        input = preprocess(example_img)
        logits = model(input.unsqueeze(0))  # use unsqueeze to add batch dimension
        print(
            f"""
    ResNet50 Outputs
    ================
    Result Type: {type(logits)}
    Result Shape: {logits.shape}
    Sample Prediction: {prediction_label(logits)}
    """
        )
        display(example_img)

.. parsed-literal::

    ResNet50 Outputs
    ================
    Result Type: <class 'torch.Tensor'>
    Result Shape: torch.Size([1, 1000])
    Sample Prediction: 3 ['n01491361', 'tiger_shark']

.. image:: wrap_image_classification_model_files/wrap_image_classification_model_12_1.png

.. parsed-literal::

    ResNet50 Outputs
    ================
    Result Type: <class 'torch.Tensor'>
    Result Shape: torch.Size([1, 1000])
    Sample Prediction: 48 ['n01695060', 'Komodo_dragon']

.. image:: wrap_image_classification_model_files/wrap_image_classification_model_12_3.png
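Before creating the wrapper, it can also be useful to look at the
model’s confidence rather than just its argmax prediction. The
following optional cell is a minimal sketch (not required for
wrapping) that prints the top-5 classes and their softmax
probabilities for the first sample image:

.. code:: ipython3

    # Optional sanity check: top-5 classes and softmax probabilities for
    # the first sample image (assumes model, preprocess, labels, and
    # example_imgs from the cells above)
    with pt.no_grad():
        logits = model(preprocess(example_imgs[0]).unsqueeze(0))

    probs = logits.softmax(dim=-1).squeeze(0)
    top_probs, top_idxs = probs.topk(5)
    for p, i in zip(top_probs.tolist(), top_idxs.tolist()):
        print(f"{labels[str(i)][1]:>20s}: {p:.3f}")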
3 Create the MAITE Model Wrapper
--------------------------------

A MAITE
`maite.protocols.image_classification.Model <../generated/maite.protocols.image_classification.Model.html>`__
wrapper stores a reference to a model and model metadata. For this
model we also store the preprocessing function.

A MAITE-compliant image classification Model need only implement the
following method:

- ``__call__(batch: Sequence[ArrayLike])`` to make a model prediction
  for the inputs in an input batch. It must transform its inputs from a
  ``Sequence[ArrayLike]`` to the format expected by the model, and
  transform the model outputs to a ``Sequence[ArrayLike]`` containing
  the predictions.

and an attribute:

- ``metadata``, which is a typed dictionary containing an ``id`` field
  for the model.

Note that MAITE requires the dimensions of each image in the input
batch to be ``(C, H, W)``, which correspond to the image’s color
channels, height, and width, respectively. Torchvision’s data
preprocessing function, ``transforms()``, mentioned in Section 2,
already accepts ``torch.Tensors`` of shape ``(C, H, W)``, which are
compatible with MAITE ``ArrayLike``.

.. code:: ipython3

    class TorchvisionResNetModel:
        def __init__(
            self,
            model: torchvision.models.ResNet,
            preprocess: Callable[[pt.Tensor], pt.Tensor],
        ) -> None:
            self.metadata: ModelMetadata = {
                "id": "Torchvision ResNet ImageNet 1k",
                "index2label": {
                    int(idx): label for idx, [_wordnetid, label] in labels.items()
                },
            }
            self.model = model
            self.preprocess = preprocess

        def __call__(self, batch: Sequence[ArrayLike]) -> Sequence[pt.Tensor]:
            # Preprocess the inputs to ensure they match the model's input format
            imgs_chw = []
            for img_chw in batch:
                imgs_chw.append(self.preprocess(pt.as_tensor(img_chw)))

            # Create a shape-(N,C,H,W) tensor from the list of shape-(C,H,W) tensors
            # Note: Images have been preprocessed to all have the same shape
            img_nchw = pt.stack(imgs_chw)

            # Call the model
            logits = self.model(img_nchw)

            # Convert the shape-(N,num_classes) logits tensor into a list of
            # shape-(num_classes,) tensors
            return [t for t in logits]

    # Wrap the Torchvision ResNet model
    wrapped_model: ic.Model = TorchvisionResNetModel(model, preprocess)

4 Examine the MAITE-wrapped Model Output
----------------------------------------

Build a batch of MAITE ``ArrayLike`` test images and visualize the
wrapped model’s output:

.. code:: ipython3

    def pil_image_to_maite(img_pil):
        # Convert the PIL image to a NumPy array
        img_hwc = np.array(img_pil)  # shape (H, W, C)

        # Use MAITE array index convention for representing images: (C, H, W)
        img_chw = img_hwc.transpose(2, 0, 1)
        return img_chw

    batch = [pil_image_to_maite(i) for i in example_imgs]

.. code:: ipython3

    predictions = wrapped_model(batch)

    print(
        f"""
    ResNet50 Outputs
    ================
    Result Type: {type(predictions)}
    Individual Prediction Type: {type(predictions[0])}
    Individual Prediction Shape: {pt.as_tensor(predictions[0]).shape}
    Sample Predictions:"""
    )

    for prediction, example_img in zip(predictions, example_imgs):
        print(f"  Predicted label = {prediction_label(prediction)}")
        display(example_img)

.. parsed-literal::

    ResNet50 Outputs
    ================
    Result Type: <class 'list'>
    Individual Prediction Type: <class 'torch.Tensor'>
    Individual Prediction Shape: torch.Size([1000])
    Sample Predictions:
      Predicted label = 3 ['n01491361', 'tiger_shark']

.. image:: wrap_image_classification_model_files/wrap_image_classification_model_20_1.png

.. parsed-literal::

      Predicted label = 48 ['n01695060', 'Komodo_dragon']

.. image:: wrap_image_classification_model_files/wrap_image_classification_model_20_3.png
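Because the wrapper accepts any ``Sequence[ArrayLike]``, the batch
elements do not have to be NumPy arrays. As an illustrative extra check
(not part of the original steps), the same wrapped model can be called
with a batch of ``torch.Tensor`` images:

.. code:: ipython3

    # The wrapper converts each ArrayLike with pt.as_tensor, so torch
    # tensors work just as well as NumPy arrays
    batch_torch = [pt.as_tensor(img_chw) for img_chw in batch]
    for prediction in wrapped_model(batch_torch):
        print(f"Predicted label = {prediction_label(prediction)}")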
5 Conclusion
------------

Wrapping image classification models with MAITE ensures
interoperability and simplifies integration into test & evaluation
(T&E) workflows that “know” how to work with MAITE models, since the
model’s inputs and outputs are standardized. T&E workflows designed
around MAITE protocols, such as MAITE’s
`evaluate <../generated/maite.workflows.evaluate.html>`__, will work
seamlessly across models, including ResNet50.

The key to model wrapping is to define the following (see the skeleton
sketch below):

- a ``__call__`` method that receives an input of type
  ``maite.protocols.image_classification.InputBatchType`` and returns a
  ``Sequence[ArrayLike]``
- a ``metadata`` typed dictionary attribute with at least an ``id``
  field
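To recap in code, a minimal skeleton for wrapping any image classifier
looks like the following. The class name, ``id`` value, and placeholder
body are illustrative, not part of MAITE:

.. code:: ipython3

    class MyClassifierWrapper:
        # Illustrative skeleton; replace the placeholder body with your
        # native model's preprocessing and inference logic
        def __init__(self) -> None:
            self.metadata: ModelMetadata = {"id": "my-classifier"}  # placeholder id

        def __call__(self, batch: Sequence[ArrayLike]) -> Sequence[pt.Tensor]:
            # 1. Convert each shape-(C, H, W) ArrayLike to the native model's input format
            # 2. Run the native model on the batched inputs
            # 3. Return one shape-(num_classes,) prediction per input image
            raise NotImplementedError

    # Assigning to ic.Model lets a static type checker verify that the
    # wrapper structurally conforms to the MAITE Model protocol
    my_model: ic.Model = MyClassifierWrapper()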