Copyright 2025, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
SPDX-License-Identifier: MIT

Wrap an Object Detection Model

This guide explains how to wrap common object detection models, such as the Torchvision Faster R-CNN and Ultralytics YOLOv8, to create models that conform to MAITE’s maite.protocols.object_detection.Model protocol.

The general steps for wrapping a model are:

  • Understand the source (native) model

  • Create a wrapper that makes the source model conform to the MAITE model protocol (interface)

  • Verify that the wrapped model works correctly and has no static type checking errors

1 Load the Pretrained Faster R-CNN and YOLOv8 Models

First we load the required Python libraries:

import io
import urllib.request
from dataclasses import asdict, dataclass
from typing import Sequence

import numpy as np
import PIL.Image
import torch
from IPython.display import display
from torchvision.models.detection import (
    FasterRCNN,
    FasterRCNN_ResNet50_FPN_V2_Weights,
    fasterrcnn_resnet50_fpn_v2,
)
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from torchvision.utils import draw_bounding_boxes
from ultralytics import YOLO

import maite.protocols.object_detection as od
from maite.protocols import ArrayLike, ModelMetadata

%load_ext watermark
%watermark -iv -v
Python implementation: CPython
Python version       : 3.9.21
IPython version      : 8.18.1

maite      : 0.7.3
IPython    : 8.18.1
torchvision: 0.21.0
PIL        : 11.1.0
torch      : 2.6.0
ultralytics: 8.3.77
numpy      : 1.26.4

Next we instantiate the (native) Faster R-CNN and YOLOv8 models using the pretrained weights provided by their respective libraries. The models were trained on the COCO dataset.

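Note that both models accept additional parameters that can affect their performance. Faster R-CNN takes them at construction (e.g., box_score_thresh=0.9 below, which keeps only detections scoring above 0.9), while YOLO takes most of them at prediction time (e.g., a conf threshold).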

# Load R-CNN model
rcnn_weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
rcnn_model = fasterrcnn_resnet50_fpn_v2(weights=rcnn_weights, box_score_thresh=0.9)
rcnn_model.eval()  # set the RCNN to eval mode (defaults to training)

# Load YOLOv8 Nano model
yolov8_model = YOLO("yolov8n.pt")  # weights will download automatically

2 Perform Model Inference on Sample Images

Object detection models vary in how they expect inputs to be preprocessed. For example, the R-CNN model expects float tensors with pixel values scaled to [0, 1], so we must convert our images to tensors and apply the preprocessing transforms associated with the pretrained weights before inference (the model then resizes and normalizes its inputs internally). In contrast, the YOLOv8 model accepts a variety of input formats, including PIL images, and resizes and normalizes them automatically.
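As a rough illustration of the R-CNN preprocessing, the following NumPy sketch mimics what the transform returned by weights.transforms() does to a uint8 image: it converts the pixel values to floating point and rescales them from [0, 255] to [0, 1]. It is a stand-in written with NumPy, not the torchvision implementation, and the function name to_float_image is hypothetical.

```python
import numpy as np


def to_float_image(img_chw: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for the detection preset transform.

    Converts a uint8 (C, H, W) image to float32 with values in [0, 1].
    Resizing and normalization are left to the detection model itself.
    """
    if img_chw.dtype != np.uint8:
        raise TypeError(f"expected uint8 input, got {img_chw.dtype}")
    return img_chw.astype(np.float32) / 255.0


# A tiny 3-channel, 2x2 "image" containing the extreme pixel values
img = np.array([[[0, 255], [128, 64]]] * 3, dtype=np.uint8)
out = to_float_image(img)
print(out.dtype, out.shape, out.min(), out.max())
```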

We will first download two sample images.

img_url = "https://github.com/pytorch/vision/blob/main/test/assets/encode_jpeg/grace_hopper_517x606.jpg?raw=true"
image_data = urllib.request.urlopen(img_url).read()
pil_img_1 = PIL.Image.open(io.BytesIO(image_data))

img_url = "https://www.ultralytics.com/images/bus.jpg"
image_data = urllib.request.urlopen(img_url).read()
pil_img_2 = PIL.Image.open(io.BytesIO(image_data))

Next we preprocess the images for the R-CNN model and run inference:

# Convert to PyTorch tensors
tensor_img_1 = pil_to_tensor(pil_img_1)
tensor_img_2 = pil_to_tensor(pil_img_2)
rcnn_imgs = [tensor_img_1, tensor_img_2]

# Get the inference transforms associated with these pretrained weights
preprocess = rcnn_weights.transforms()

# Apply inference preprocessing transforms
batch = [preprocess(img) for img in rcnn_imgs]

rcnn_preds = rcnn_model(batch)

Finally, we run inference with the YOLO model directly on the PIL images:

yolo_imgs = [pil_img_1, pil_img_2]

yolo_preds = yolov8_model(yolo_imgs, verbose=False)

As with the inputs, the two models’ outputs also differ:

rcnn_pred = rcnn_preds[0]
yolo_pred = yolo_preds[0]

rcnn_boxes = rcnn_pred["boxes"]
yolo_boxes = yolo_pred.boxes

print(
    f"""
R-CNN Outputs
=============
Result Type: {type(rcnn_pred)}
Result Attributes: {rcnn_pred.keys()}
Box Types: {type(rcnn_boxes)}

YOLO Outputs
============
Result Type: {type(yolo_pred)}
Result Attributes: {yolo_pred._keys}
Box Types: {type(yolo_boxes)}
"""
)
R-CNN Outputs
=============
Result Type: <class 'dict'>
Result Attributes: dict_keys(['boxes', 'labels', 'scores'])
Box Types: <class 'torch.Tensor'>

YOLO Outputs
============
Result Type: <class 'ultralytics.engine.results.Results'>
Result Attributes: ('boxes', 'masks', 'probs', 'keypoints', 'obb')
Box Types: <class 'ultralytics.engine.results.Boxes'>

The R-CNN model returns a dictionary with the keys boxes, labels, and scores, while the YOLO model returns a custom Results object. We now wrap both models with MAITE to standardize their inputs and outputs.

3 Create MAITE Wrappers for the R-CNN and YOLOv8 Models

We create two separate classes that implement the maite.protocols.object_detection.Model protocol. A MAITE object detection model only needs to have the following method:

  • __call__(input_batch: Sequence[ArrayLike]), which returns a prediction for each input in the batch.

and an attribute:

  • metadata, which is a typed dictionary containing an id field for the model plus an optional (but highly recommended) map from class indexes to labels called index2label.
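To make this structural ("duck") typing concrete, here is a self-contained sketch that uses typing.Protocol to express the same two requirements. DetectorLike, Target, and ConstantDetector are hypothetical names for illustration and are not MAITE’s actual definitions; the point is that a class conforms simply by having the right attribute and __call__ signature, without inheriting from anything.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Protocol, Sequence, runtime_checkable

import numpy as np


@dataclass
class Target:
    """Hypothetical per-image prediction with the three required fields."""

    boxes: np.ndarray  # (N, 4)
    labels: np.ndarray  # (N,)
    scores: np.ndarray  # (N,)


@runtime_checkable
class DetectorLike(Protocol):
    """Simplified stand-in for the model protocol (not MAITE's definition)."""

    metadata: Dict[str, Any]

    def __call__(self, batch: Sequence[Any]) -> Sequence[Target]: ...


class ConstantDetector:
    """Hypothetical model that predicts no objects for any image."""

    def __init__(self) -> None:
        self.metadata: Dict[str, Any] = {"id": "constant-detector", "index2label": {0: "person"}}

    def __call__(self, batch: Sequence[Any]) -> List[Target]:
        # One (empty) prediction per input image, in the same order as the batch
        return [
            Target(np.empty((0, 4)), np.empty(0, dtype=np.int64), np.empty(0))
            for _ in batch
        ]


model = ConstantDetector()
print(isinstance(model, DetectorLike))  # True: matched by structure, no inheritance needed
```

A static type checker such as pyright performs the same structural comparison when you annotate a variable with the protocol type, which is what the od.Model annotations later in this guide rely on.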

Wrapping the models with MAITE provides a consistent interface for handling inputs and outputs (e.g., boxes and labels); this interoperability simplifies integration with diverse workflows and tools, such as downstream test & evaluation pipelines.

We begin by creating common input and output types to be used by both models, as well as an image rendering utility function.

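MAITE object detection models must return, for each input image, an object compatible with maite.protocols.object_detection.ObjectDetectionTarget, i.e., one that exposes boxes, labels, and scores. The ObjectDetectionTargetImpl dataclass below serves this purpose for both wrappers.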

imgs: od.InputBatchType = [tensor_img_1, tensor_img_2]

@dataclass
class ObjectDetectionTargetImpl:
    boxes: np.ndarray
    labels: np.ndarray
    scores: np.ndarray

def render_wrapped_results(imgs, preds, model_metadata):
    names = model_metadata["index2label"]
    for img, pred in zip(imgs, preds):
        # Look up the human-readable class name for each predicted label
        pred_labels = [names[int(label)] for label in pred.labels]
        annotated = draw_bounding_boxes(
            img,
            boxes=torch.as_tensor(pred.boxes),
            labels=pred_labels,
            colors="red",
            width=4,
            font="DejaVuSans",  # if necessary, change to a TrueType font available on your system
            font_size=30,
        )
        im = to_pil_image(annotated.detach())
        w, h = im.size  # PIL reports size as (width, height)
        im = im.resize((w // 2, h // 2))  # display at half size
        display(im)
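Before wrapping, it can help to have a small sanity check for what each per-image target should look like. The helper below is a sketch under stated assumptions: the name check_target is hypothetical, and the (N, 4) x1, y1, x2, y2 box layout is the one both wrappers in this guide return rather than a rule quoted from MAITE.

```python
import numpy as np


def check_target(boxes, labels, scores) -> int:
    """Sanity-check one image's detections and return how many there are.

    Assumes (N, 4) boxes in (x1, y1, x2, y2) pixel coordinates, the layout
    both wrappers in this guide return.
    """
    boxes = np.asarray(boxes, dtype=np.float64)
    labels, scores = np.asarray(labels), np.asarray(scores)
    if boxes.size == 0:
        boxes = boxes.reshape(0, 4)  # treat "no detections" as shape (0, 4)
    if boxes.ndim != 2 or boxes.shape[1] != 4:
        raise ValueError(f"boxes must have shape (N, 4), got {boxes.shape}")
    n = boxes.shape[0]
    if labels.shape != (n,) or scores.shape != (n,):
        raise ValueError("labels and scores must both have shape (N,)")
    if np.any(boxes[:, 2] < boxes[:, 0]) or np.any(boxes[:, 3] < boxes[:, 1]):
        raise ValueError("every box needs x2 >= x1 and y2 >= y1")
    return n


# Values copied from the first wrapped R-CNN prediction in Section 3.1
n = check_target(
    boxes=[[5.1904, 40.48, 513.56, 601.49], [215.86, 414.02, 297.24, 482.01]],
    labels=[1, 32],
    scores=[0.99896, 0.97159],
)
print(n)  # 2
```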

3.1 Wrap the R-CNN Model

As mentioned in Section 2, the R-CNN model needs its inputs preprocessed before inference.

Because the native model already accepts a list of (C, H, W) image tensors, which conforms to maite.protocols.object_detection.InputBatchType as-is, the wrapper only needs to apply that preprocessing internally.

class WrappedRCNN:
    def __init__(
        self, model: FasterRCNN, weights: FasterRCNN_ResNet50_FPN_V2_Weights, id: str, **kwargs
    ):
        self.model = model
        self.weights = weights
        self.kwargs = kwargs

        # Add required model metadata attribute
        index2label = {i: category for i, category in enumerate(weights.meta["categories"])}
        self.metadata: ModelMetadata = {
            "id": id,
            "index2label": index2label
        }

    def __call__(self, batch: od.InputBatchType) -> Sequence[ObjectDetectionTargetImpl]:
        # Get MAITE inputs ready for native model: bridge any ArrayLike to a
        # torch.Tensor, then apply the weights' preprocessing transforms
        preprocess = self.weights.transforms()
        native_batch = [
            preprocess(img if isinstance(img, torch.Tensor) else torch.as_tensor(np.asarray(img)))
            for img in batch
        ]

        # Perform inference (gradients are not needed)
        with torch.no_grad():
            results = self.model(native_batch, **self.kwargs)

        # Restructure results to conform to MAITE
        all_detections = []
        for result in results:
            boxes = result["boxes"].detach().cpu().numpy()
            labels = result["labels"].detach().cpu().numpy()
            scores = result["scores"].detach().cpu().numpy()
            all_detections.append(
                ObjectDetectionTargetImpl(boxes=boxes, labels=labels, scores=scores)
            )

        return all_detections

wrapped_rcnn_model: od.Model = WrappedRCNN(rcnn_model, rcnn_weights, "TorchVision.FasterRCNN_ResNet50_FPN_V2")
wrapped_rcnn_preds = wrapped_rcnn_model(imgs)
wrapped_rcnn_preds
[ObjectDetectionTargetImpl(boxes=array([[     5.1904,       40.48,      513.56,      601.49],
        [     215.86,      414.02,      297.24,      482.01]], dtype=float32), labels=array([ 1, 32]), scores=array([    0.99896,     0.97159], dtype=float32)),
 ObjectDetectionTargetImpl(boxes=array([[     53.438,      401.78,      235.58,      907.62],
        [     21.099,       224.5,       802.2,      756.62],
        [     220.23,      404.51,      347.92,      860.75],
        [     668.02,      399.47,         810,       881.8],
        [    0.87665,      544.77,      75.116,      878.61]], dtype=float32), labels=array([1, 6, 1, 1, 1]), scores=array([    0.99916,     0.99915,     0.99888,     0.99819,     0.98551], dtype=float32))]
render_wrapped_results(imgs, wrapped_rcnn_preds, wrapped_rcnn_model.metadata)
(Output: the two sample images, displayed at half size, with the wrapped R-CNN model’s detections drawn as labeled red boxes.)
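The restructuring loop in WrappedRCNN.__call__ can be checked without loading any weights. In this sketch, plain NumPy arrays stand in for torchvision’s output tensors, and Target and dicts_to_targets are hypothetical names. The second entry shows that an image with no detections still yields its own (empty) target, which keeps predictions aligned with the input batch.

```python
from dataclasses import dataclass
from typing import Dict, List, Sequence

import numpy as np


@dataclass
class Target:
    boxes: np.ndarray
    labels: np.ndarray
    scores: np.ndarray


def dicts_to_targets(results: Sequence[Dict[str, np.ndarray]]) -> List[Target]:
    """Convert torchvision-style per-image dicts into attribute-style targets.

    Plain NumPy arrays stand in for the detached tensors used in WrappedRCNN.
    """
    return [
        Target(
            boxes=np.asarray(r["boxes"], dtype=np.float32).reshape(-1, 4),
            labels=np.asarray(r["labels"], dtype=np.int64),
            scores=np.asarray(r["scores"], dtype=np.float32),
        )
        for r in results
    ]


native = [
    {"boxes": np.array([[5.19, 40.48, 513.56, 601.49]]), "labels": np.array([1]), "scores": np.array([0.999])},
    {"boxes": np.empty((0, 4)), "labels": np.empty(0), "scores": np.empty(0)},  # no detections
]
targets = dicts_to_targets(native)
print([t.boxes.shape for t in targets])  # [(1, 4), (0, 4)]
```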

3.2 Wrap the YOLO Model

We previously passed PIL images to the (native) YOLO model, while the MAITE-compliant wrapper will receive inputs of type maite.protocols.object_detection.InputBatchType (an alias for Sequence[ArrayLike]).

Note that MAITE requires the dimensions of each image in the input batch to be (C, H, W), which corresponds to the image’s color channels, height, and width, respectively.

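YOLO models, however, expect NumPy image inputs to be (H, W, C), so the wrapper adds a preprocessing step to transpose the data.

Furthermore, the underlying YOLO models handle each input format (e.g., filename, PIL image, NumPy array, PyTorch tensor) differently, so the wrapper converts every input to an np.ndarray, the array type Ultralytics also converts PIL images into internally. One caveat: Ultralytics treats NumPy arrays as BGR (OpenCV channel order) and flips PIL images from RGB to BGR itself, so passing RGB arrays unchanged, as the wrapper below does, swaps the red and blue channels relative to the native PIL call.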
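The layout change in the wrapper is a single axis permutation. The sketch below isolates it and, as an optional extra step that the wrapper in this guide does not perform, also reverses the channel axis. Because Ultralytics documents NumPy inputs as BGR, this flip is what a caller would add to match the color handling of the native PIL path. The function name is hypothetical.

```python
import numpy as np


def chw_rgb_to_hwc_bgr(img_chw) -> np.ndarray:
    """Convert a (C, H, W) RGB array to a contiguous (H, W, C) BGR array.

    The BGR flip assumes the consumer follows the OpenCV channel order.
    """
    hwc = np.asarray(img_chw).transpose((1, 2, 0))  # (C, H, W) -> (H, W, C)
    return np.ascontiguousarray(hwc[:, :, ::-1])  # reverse channels: RGB -> BGR


# 3 x 2 x 4 test image whose channels are filled with 0 (R), 1 (G), 2 (B)
rgb = np.stack([np.full((2, 4), c, dtype=np.uint8) for c in range(3)])
bgr = chw_rgb_to_hwc_bgr(rgb)
print(rgb.shape, bgr.shape, bgr[0, 0].tolist())  # (3, 2, 4) (2, 4, 3) [2, 1, 0]
```

np.ascontiguousarray is used because both transpose and the [::-1] slice return strided views, and some downstream code assumes a contiguous buffer.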

class WrappedYOLO:
    def __init__(self, model: YOLO, id: str, **kwargs):
        self.model = model
        self.kwargs = kwargs

        # Add required model metadata attribute
        self.metadata: ModelMetadata = {
            "id": id,
            "index2label": model.names  # already a mapping from integer class index to string name
        }

    def __call__(self, batch: od.InputBatchType) -> Sequence[ObjectDetectionTargetImpl]:
        # Get MAITE inputs ready for native model
        # Bridge/convert input to np.ndarray and transpose (C, H, W) -> (H, W, C)
        native_batch = [np.asarray(x).transpose((1, 2, 0)) for x in batch]

        # Perform inference
        results = self.model(native_batch, **self.kwargs)

        # Restructure results to conform to MAITE (exactly one target per input image)
        all_detections = []
        for result in results:
            detections = result.boxes
            if detections is None:
                # No box output: use an empty target so predictions stay aligned with inputs
                boxes = np.empty((0, 4), dtype=np.float32)
                labels = np.empty((0,), dtype=np.uint8)
                scores = np.empty((0,), dtype=np.float32)
            else:
                detections = detections.cpu().numpy()
                boxes = np.asarray(detections.xyxy)
                labels = np.asarray(detections.cls, dtype=np.uint8)
                scores = np.asarray(detections.conf)
            all_detections.append(
                ObjectDetectionTargetImpl(boxes=boxes, labels=labels, scores=scores)
            )

        return all_detections

wrapped_yolov8_model: od.Model = WrappedYOLO(yolov8_model, id="Ultralytics.YOLOv8", verbose=False)

We can now visualize the wrapped model’s output for both images.

wrapped_yolo_preds = wrapped_yolov8_model(imgs)
wrapped_yolo_preds
[ObjectDetectionTargetImpl(boxes=array([[     9.0822,      20.553,         517,      604.41],
        [     218.77,      414.25,      300.06,      538.04],
        [     199.25,      414.16,      300.59,      605.61]], dtype=float32), labels=array([ 0, 27, 27], dtype=uint8), scores=array([    0.89559,     0.60285,     0.47824], dtype=float32)),
 ObjectDetectionTargetImpl(boxes=array([[     50.009,      397.59,      243.39,      905.03],
        [     223.38,      405.18,      344.82,      857.35],
        [     670.62,      378.47,      809.87,      875.35],
        [     29.816,      228.97,      797.11,      751.33],
        [    0.15492,       550.5,      61.616,      870.62]], dtype=float32), labels=array([0, 0, 0, 5, 0], dtype=uint8), scores=array([    0.87989,     0.87631,     0.86818,     0.84481,     0.44715], dtype=float32))]
render_wrapped_results(imgs, wrapped_yolo_preds, wrapped_yolov8_model.metadata)
(Output: the two sample images, displayed at half size, with the wrapped YOLOv8 model’s detections drawn as labeled red boxes.)
# Get predictions from both models for first image in batch
wrapped_rcnn_pred = wrapped_rcnn_preds[0]
wrapped_yolo_pred = wrapped_yolo_preds[0]

wrapped_rcnn_fields = vars(wrapped_rcnn_pred).keys()
wrapped_yolo_fields = vars(wrapped_yolo_pred).keys()

print(
    f"""
Wrapped R-CNN Outputs
=====================
Result Type: {type(wrapped_rcnn_pred)}
Result Attributes: {wrapped_rcnn_fields}

Wrapped YOLO Outputs
====================
Result Type: {type(wrapped_yolo_pred)}
Result Attributes: {wrapped_yolo_fields}
"""
)
Wrapped R-CNN Outputs
=====================
Result Type: <class '__main__.ObjectDetectionTargetImpl'>
Result Attributes: dict_keys(['boxes', 'labels', 'scores'])

Wrapped YOLO Outputs
====================
Result Type: <class '__main__.ObjectDetectionTargetImpl'>
Result Attributes: dict_keys(['boxes', 'labels', 'scores'])
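Because both wrappers now return targets with the same fields, model-agnostic utilities only need to be written once. As an illustration, this NumPy sketch computes the pairwise intersection-over-union (IoU) between two models’ boxes, here the “person” detections each model produced for the first image. The function name is hypothetical, and boxes are assumed to be in the x1, y1, x2, y2 format both wrappers return.

```python
import numpy as np


def pairwise_iou(boxes_a, boxes_b) -> np.ndarray:
    """IoU between every box in boxes_a (N, 4) and boxes_b (M, 4), xyxy format."""
    a = np.asarray(boxes_a, dtype=np.float64)
    b = np.asarray(boxes_b, dtype=np.float64)
    # Intersection rectangle for every (i, j) pair via broadcasting: (N, 1) vs (1, M)
    x1 = np.maximum(a[:, None, 0], b[None, :, 0])
    y1 = np.maximum(a[:, None, 1], b[None, :, 1])
    x2 = np.minimum(a[:, None, 2], b[None, :, 2])
    y2 = np.minimum(a[:, None, 3], b[None, :, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    area_a = (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1])
    area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    union = area_a[:, None] + area_b[None, :] - inter
    return inter / union


# Person boxes for the first image, copied from the wrapped outputs above
rcnn_person = [[5.1904, 40.48, 513.56, 601.49]]
yolo_person = [[9.0822, 20.553, 517.0, 604.41]]
print(pairwise_iou(rcnn_person, yolo_person))
```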

Notice that by wrapping both models to be MAITE-compliant, we were able to:

  • Use the same input data, imgs, for both models.

  • Produce standardized outputs conforming to ObjectDetectionTarget from both models.

  • Render the results of both models using the same function, due to the standardized inputs and outputs.
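One difference the shared interface does not remove is the class-index space. In the outputs above, “person” is label 1 for Faster R-CNN but 0 for YOLOv8, because torchvision’s COCO category list begins with a "__background__" entry. Cross-model comparisons should therefore go through each model’s metadata["index2label"] rather than raw integers. The dictionaries below are small hand-copied excerpts for illustration, not the full mappings, and to_names is a hypothetical helper.

```python
import numpy as np

# Small excerpts of each model's index2label mapping (COCO class names).
# torchvision's list starts with "__background__", so its indices are offset.
rcnn_index2label = {0: "__background__", 1: "person", 6: "bus", 32: "tie"}
yolo_index2label = {0: "person", 5: "bus", 27: "tie"}


def to_names(labels, index2label):
    """Map integer class indices to class names via model metadata."""
    return [index2label[int(i)] for i in np.asarray(labels)]


# Labels for the second image, copied from the wrapped outputs above
rcnn_names = to_names([1, 6, 1, 1, 1], rcnn_index2label)
yolo_names = to_names([0, 0, 0, 5, 0], yolo_index2label)
print(sorted(rcnn_names) == sorted(yolo_names))  # True: same classes, different indices
```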

4 Summary

Wrapping object detection models with MAITE ensures interoperability and simplifies integration with additional T&E processes. By standardizing inputs and outputs, you can build workflows that apply unchanged to different models, such as Faster R-CNN and YOLOv8.

The key to model wrapping is to define the following:

  • A __call__ method that receives a maite.protocols.object_detection.InputBatchType as input and returns maite.protocols.object_detection.TargetBatchType.

  • A metadata attribute that’s a typed dictionary with at least an “id” field (and, ideally, an index2label mapping).
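Both wrappers in this guide follow the same three-step shape: convert MAITE inputs to the native format, run the native model, and convert native outputs back to targets. The sketch below factors that pattern into a small reusable class. Everything in it (WrappedDetector, toy_model, and the dict keys the toy model returns) is hypothetical and uses NumPy only; in practice you would plug in real preprocessing and postprocessing such as the code in WrappedRCNN and WrappedYOLO.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Sequence

import numpy as np


@dataclass
class Target:
    boxes: np.ndarray
    labels: np.ndarray
    scores: np.ndarray


class WrappedDetector:
    """Hypothetical template: native model + input bridge + output bridge."""

    def __init__(
        self,
        native_model: Callable[[List[Any]], Sequence[Any]],
        to_native: Callable[[Any], Any],
        from_native: Callable[[Any], Target],
        id: str,
        index2label: Dict[int, str],
    ):
        self.native_model = native_model
        self.to_native = to_native
        self.from_native = from_native
        self.metadata = {"id": id, "index2label": index2label}

    def __call__(self, batch: Sequence[Any]) -> List[Target]:
        native_inputs = [self.to_native(x) for x in batch]  # MAITE -> native
        native_outputs = self.native_model(native_inputs)  # inference
        return [self.from_native(out) for out in native_outputs]  # native -> MAITE


def toy_model(images):
    # Toy native model: one full-image box per (H, W, C) input, class 0
    return [{"xyxy": [[0, 0, img.shape[1], img.shape[0]]], "cls": [0], "conf": [1.0]} for img in images]


def chw_to_hwc(x):
    return np.asarray(x).transpose((1, 2, 0))  # (C, H, W) -> (H, W, C)


def toy_output_to_target(o):
    return Target(np.asarray(o["xyxy"], dtype=np.float32), np.asarray(o["cls"]), np.asarray(o["conf"]))


model = WrappedDetector(toy_model, chw_to_hwc, toy_output_to_target, id="toy-detector", index2label={0: "object"})
preds = model([np.zeros((3, 8, 5)), np.zeros((3, 4, 4))])
print(preds[0].boxes.tolist())  # [[0.0, 0.0, 5.0, 8.0]]
```

Keeping the two bridge functions separate from the class makes each one easy to unit-test on its own, which is the same benefit the guide gets from verifying the wrapped models’ outputs.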