..
   Copyright 2025, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
   Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
   SPDX-License-Identifier: MIT

Wrap an Object Detection Model
==============================

This guide explains how to wrap common object detection models, such as the Torchvision Faster R-CNN and Ultralytics YOLOv8, to create models that conform to MAITE’s `maite.protocols.object_detection.Model <../generated/maite.protocols.object_detection.Model.html>`__ protocol.

The general steps for wrapping a model are:

- Understand the source (native) model
- Create a wrapper that makes the source model conform to the MAITE model protocol (interface)
- Verify that the wrapped model works correctly and has no static type checking errors

1 Load the Pretrained Faster R-CNN and YOLOv8 Models
----------------------------------------------------

First we load the required Python libraries:

.. code:: ipython3

    import io
    import urllib.request
    from dataclasses import asdict, dataclass
    from typing import Sequence

    import numpy as np
    import PIL.Image
    import torch
    from IPython.display import display
    from torchvision.models.detection import (
        FasterRCNN,
        FasterRCNN_ResNet50_FPN_V2_Weights,
        fasterrcnn_resnet50_fpn_v2,
    )
    from torchvision.transforms.functional import pil_to_tensor, to_pil_image
    from torchvision.utils import draw_bounding_boxes
    from ultralytics import YOLO

    import maite.protocols.object_detection as od
    from maite.protocols import ArrayLike, ModelMetadata

    %load_ext watermark
    %watermark -iv -v


.. parsed-literal::

    Creating new Ultralytics Settings v0.0.6 file ✅
    View Ultralytics Settings with 'yolo settings' or at '/root/.config/Ultralytics/settings.json'
    Update Settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings.


.. parsed-literal::

    Python implementation: CPython
    Python version       : 3.9.21
    IPython version      : 8.18.1

    maite      : 0.7.3
    IPython    : 8.18.1
    torchvision: 0.21.0
    PIL        : 11.1.0
    torch      : 2.6.0
    ultralytics: 8.3.77
    numpy      : 1.26.4


Next we instantiate the (native) Faster R-CNN and YOLOv8 models using the pretrained weights provided by their respective libraries. The models were trained on the COCO dataset. Note that there are additional parameters you could pass into each model’s initialization function, which can affect its performance.

.. code:: ipython3

    # Load R-CNN model
    rcnn_weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
    rcnn_model = fasterrcnn_resnet50_fpn_v2(weights=rcnn_weights, box_score_thresh=0.9)
    rcnn_model.eval()  # set the R-CNN to eval mode (defaults to training mode)

    # Load YOLOv8 Nano model
    yolov8_model = YOLO("yolov8n.pt")  # weights will download automatically


.. parsed-literal::

    Downloading: "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_v2_coco-dd69338a.pth" to /root/.cache/torch/hub/checkpoints/fasterrcnn_resnet50_fpn_v2_coco-dd69338a.pth


.. parsed-literal::

    0%|          | 0.00/167M [00:00

2 Load Example Images and Run the Native Models
-----------------------------------------------

For the rest of this guide we use two example images, available both as ``(C, H, W)`` PyTorch tensors (``tensor_img_1`` and ``tensor_img_2``) and as PIL images. Running each native model on them shows that the two libraries structure their outputs quite differently; the R-CNN model also requires its weights’ preprocessing transform to be applied manually, while the YOLO model accepts PIL images directly. Comparing the structure of the two models’ raw outputs:

.. parsed-literal::

    R-CNN Outputs
    =============
    Result Type: <class 'dict'>
    Result Attributes: dict_keys(['boxes', 'labels', 'scores'])
    Box Types: <class 'torch.Tensor'>

    YOLO Outputs
    ============
    Result Type: <class 'ultralytics.engine.results.Results'>
    Result Attributes: ('boxes', 'masks', 'probs', 'keypoints', 'obb')
    Box Types: <class 'ultralytics.engine.results.Boxes'>


The R-CNN model returns a dictionary with certain keys, while the YOLO model returns a custom ``Results`` class. We now proceed to wrap both models with MAITE to get the benefits of standardizing model inputs and outputs.
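The comparison above can be reproduced with a short snippet along the following lines. This is only a minimal sketch (it is not part of the wrapping workflow) and it assumes ``tensor_img_1``, one of the example image tensors, is an ordinary ``(C, H, W)`` ``uint8`` tensor.

.. code:: ipython3

    # Minimal sketch: inspect each native model's raw output for one example image.
    preprocess = rcnn_weights.transforms()  # preprocessing transform bundled with the pretrained weights

    with torch.no_grad():
        rcnn_result = rcnn_model([preprocess(tensor_img_1)])[0]  # a plain dict of tensors

    yolo_result = yolov8_model(to_pil_image(tensor_img_1), verbose=False)[0]  # an ultralytics Results object

    print(type(rcnn_result), list(rcnn_result.keys()))
    print(type(yolo_result), type(yolo_result.boxes))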
3 Create MAITE Wrappers for the R-CNN and YOLOv8 Models
-------------------------------------------------------

We create two separate classes that implement the `maite.protocols.object_detection.Model <../generated/maite.protocols.object_detection.Model.html>`__ protocol.

A MAITE object detection model only needs to have the following method:

- ``__call__(input_batch: Sequence[ArrayLike])`` to make model predictions for inputs in the input batch

and an attribute:

- ``metadata``, which is a typed dictionary containing an ``id`` field for the model plus an optional (but highly recommended) map from class indexes to labels called ``index2label``.

Wrapping the models with MAITE provides a consistent interface for handling inputs and outputs (e.g., boxes and labels); this interoperability simplifies integration across diverse workflows and tools, e.g., downstream test & evaluation pipelines.

We begin by creating common input and output types to be used by *both* models, as well as an image rendering utility function. MAITE object detection models must output a type compatible with ``maite.protocols.object_detection.ObjectDetectionTarget``.

.. code:: ipython3

    imgs: od.InputBatchType = [tensor_img_1, tensor_img_2]


    @dataclass
    class ObjectDetectionTargetImpl:
        boxes: np.ndarray
        labels: np.ndarray
        scores: np.ndarray


    def render_wrapped_results(imgs, preds, model_metadata):
        names = model_metadata["index2label"]
        for img, pred in zip(imgs, preds):
            pred_labels = [names[label] for label in pred.labels]
            box = draw_bounding_boxes(
                img,
                boxes=torch.as_tensor(pred.boxes),
                labels=pred_labels,
                colors="red",
                width=4,
                font="DejaVuSans",  # if necessary, change to a TrueType font available on your system
                font_size=30,
            )
            im = to_pil_image(box.detach())
            w, h = im.size  # PIL reports (width, height)
            im = im.resize((w // 2, h // 2))
            display(im)

3.1 Wrap the R-CNN Model
~~~~~~~~~~~~~~~~~~~~~~~~

As mentioned in Section 2, the R-CNN model requires manual preprocessing of our current input data. Since the input expected by the native model already conforms to ``maite.protocols.object_detection.InputBatchType`` as-is, we can perform the required preprocessing inside our wrapper.

.. code:: ipython3

    class WrappedRCNN:
        def __init__(
            self,
            model: FasterRCNN,
            weights: FasterRCNN_ResNet50_FPN_V2_Weights,
            id: str,
            **kwargs,
        ):
            self.model = model
            self.weights = weights
            self.kwargs = kwargs

            # Add required model metadata attribute
            index2label = {i: category for i, category in enumerate(weights.meta["categories"])}
            self.metadata: ModelMetadata = {
                "id": id,
                "index2label": index2label,
            }

        def __call__(self, batch: od.InputBatchType) -> Sequence[ObjectDetectionTargetImpl]:
            # Get MAITE inputs ready for native model
            preprocess = self.weights.transforms()
            batch = [preprocess(img) for img in batch]

            # Perform inference
            results = self.model(batch, **self.kwargs)

            # Restructure results to conform to MAITE
            all_detections = []
            for result in results:
                boxes = result["boxes"].detach().numpy()
                labels = result["labels"].detach().numpy()
                scores = result["scores"].detach().numpy()
                all_detections.append(
                    ObjectDetectionTargetImpl(boxes=boxes, labels=labels, scores=scores)
                )

            return all_detections


    wrapped_rcnn_model: od.Model = WrappedRCNN(
        rcnn_model, rcnn_weights, "TorchVision.FasterRCNN_ResNet50_FPN_V2"
    )
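Annotating the wrapper as ``od.Model`` lets a static type checker (e.g., mypy or pyright) confirm that it structurally satisfies the protocol, and it means the wrapper can be passed to any code written purely against MAITE types. As a small illustration (not part of the original guide), a hypothetical helper like the one below works with *any* conforming object detection model:

.. code:: ipython3

    # Sketch: a generic helper written only against MAITE protocol types.
    def count_detections(model: od.Model, batch: od.InputBatchType) -> int:
        """Run a MAITE object detection model and count the predicted boxes."""
        preds = model(batch)
        return sum(len(np.asarray(pred.boxes)) for pred in preds)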
We can now run the wrapped R-CNN model on the image batch and inspect its predictions.

.. code:: ipython3

    wrapped_rcnn_preds = wrapped_rcnn_model(imgs)
    wrapped_rcnn_preds


.. parsed-literal::

    [ObjectDetectionTargetImpl(boxes=array([[ 5.1904, 40.48, 513.56, 601.49],
            [ 215.86, 414.02, 297.24, 482.01]], dtype=float32), labels=array([ 1, 32]), scores=array([ 0.99896, 0.97159], dtype=float32)),
     ObjectDetectionTargetImpl(boxes=array([[ 53.438, 401.78, 235.58, 907.62],
            [ 21.099, 224.5, 802.2, 756.62],
            [ 220.23, 404.51, 347.92, 860.75],
            [ 668.02, 399.47, 810, 881.8],
            [ 0.87665, 544.77, 75.116, 878.61]], dtype=float32), labels=array([1, 6, 1, 1, 1]), scores=array([ 0.99916, 0.99915, 0.99888, 0.99819, 0.98551], dtype=float32))]


.. code:: ipython3

    render_wrapped_results(imgs, wrapped_rcnn_preds, wrapped_rcnn_model.metadata)


.. image:: wrap_object_detection_model_files/wrap_object_detection_model_22_0.png


.. image:: wrap_object_detection_model_files/wrap_object_detection_model_22_1.png

3.2 Wrap the YOLO Model
~~~~~~~~~~~~~~~~~~~~~~~

We previously passed PIL images to the (native) YOLO model, while the MAITE-compliant wrapper will be getting inputs of type ``maite.protocols.object_detection.InputBatchType`` (which is an alias for ``Sequence[ArrayLike]``).

Note that MAITE requires the dimensions of each image in the input batch to be ``(C, H, W)``, which corresponds to the image’s color channels, height, and width, respectively. YOLO models, however, expect the input data to be ``(H, W, C)``, so we will need to add an additional preprocessing step to transpose the data.

Furthermore, the underlying YOLO models process different input formats (e.g., filename, PIL image, NumPy array, PyTorch tensor) differently, so we cast the underlying input to an ``np.ndarray`` for consistency with how PIL images are handled.

.. code:: ipython3

    class WrappedYOLO:
        def __init__(self, model: YOLO, id: str, **kwargs):
            self.model = model
            self.kwargs = kwargs

            # Add required model metadata attribute
            self.metadata: ModelMetadata = {
                "id": id,
                "index2label": model.names,  # already a mapping from integer class index to string name
            }

        def __call__(self, batch: od.InputBatchType) -> Sequence[ObjectDetectionTargetImpl]:
            # Get MAITE inputs ready for native model
            # Bridge/convert input to np.ndarray and transpose (C, H, W) -> (H, W, C)
            batch = [np.asarray(x).transpose((1, 2, 0)) for x in batch]

            # Perform inference
            results = self.model(batch, **self.kwargs)

            # Restructure results to conform to MAITE
            all_detections = []
            for result in results:
                detections = result.boxes
                if detections is None:
                    continue
                detections = detections.cpu().numpy()
                boxes = np.asarray(detections.xyxy)
                labels = np.asarray(detections.cls, dtype=np.uint8)
                scores = np.asarray(detections.conf)
                all_detections.append(
                    ObjectDetectionTargetImpl(boxes=boxes, labels=labels, scores=scores)
                )

            return all_detections


    wrapped_yolov8_model: od.Model = WrappedYOLO(yolov8_model, id="Ultralytics.YOLOv8", verbose=False)
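Because both wrappers expose the same MAITE interface, the sketch helper ``count_detections`` defined earlier runs against either model without modification:

.. code:: ipython3

    # Sketch usage: the same protocol-typed helper works for both wrapped models.
    for model in (wrapped_rcnn_model, wrapped_yolov8_model):
        n = count_detections(model, imgs)
        print(f"{model.metadata['id']}: {n} detections across {len(imgs)} images")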
We can now visualize the wrapped model’s output for both images.

.. code:: ipython3

    wrapped_yolo_preds = wrapped_yolov8_model(imgs)
    wrapped_yolo_preds


.. parsed-literal::

    [ObjectDetectionTargetImpl(boxes=array([[ 9.0822, 20.553, 517, 604.41],
            [ 218.77, 414.25, 300.06, 538.04],
            [ 199.25, 414.16, 300.59, 605.61]], dtype=float32), labels=array([ 0, 27, 27], dtype=uint8), scores=array([ 0.89559, 0.60285, 0.47824], dtype=float32)),
     ObjectDetectionTargetImpl(boxes=array([[ 50.009, 397.59, 243.39, 905.03],
            [ 223.38, 405.18, 344.82, 857.35],
            [ 670.62, 378.47, 809.87, 875.35],
            [ 29.816, 228.97, 797.11, 751.33],
            [ 0.15492, 550.5, 61.616, 870.62]], dtype=float32), labels=array([0, 0, 0, 5, 0], dtype=uint8), scores=array([ 0.87989, 0.87631, 0.86818, 0.84481, 0.44715], dtype=float32))]


.. code:: ipython3

    render_wrapped_results(imgs, wrapped_yolo_preds, wrapped_yolov8_model.metadata)


.. image:: wrap_object_detection_model_files/wrap_object_detection_model_27_0.png


.. image:: wrap_object_detection_model_files/wrap_object_detection_model_27_1.png


.. code:: ipython3

    # Get predictions from both models for first image in batch
    wrapped_rcnn_pred = wrapped_rcnn_preds[0]
    wrapped_yolo_pred = wrapped_yolo_preds[0]

    wrapped_rcnn_fields = vars(wrapped_rcnn_pred).keys()
    wrapped_yolo_fields = vars(wrapped_yolo_pred).keys()

    print(
        f"""
    Wrapped R-CNN Outputs
    =====================
    Result Type: {type(wrapped_rcnn_pred)}
    Result Attributes: {wrapped_rcnn_fields}

    YOLO Outputs
    ============
    Result Type: {type(wrapped_yolo_pred)}
    Result Attributes: {wrapped_yolo_fields}
    """
    )


.. parsed-literal::

    Wrapped R-CNN Outputs
    =====================
    Result Type: <class '__main__.ObjectDetectionTargetImpl'>
    Result Attributes: dict_keys(['boxes', 'labels', 'scores'])

    YOLO Outputs
    ============
    Result Type: <class '__main__.ObjectDetectionTargetImpl'>
    Result Attributes: dict_keys(['boxes', 'labels', 'scores'])


Notice that by wrapping both models to be MAITE-compliant, we were able to:

- Use the same input data, ``imgs``, for both models.
- Create standardized output, conforming to ``ObjectDetectionTarget``, from both models.
- Render the results of both models using the same function, due to the standardized inputs and outputs.

4 Summary
---------

Wrapping object detection models with MAITE ensures interoperability and simplifies integration with additional T&E processes. By standardizing inputs and outputs, you can create consistent workflows that work seamlessly across models, like Faster R-CNN and YOLOv8.

The key to model wrapping is to define the following:

- A ``__call__`` method that receives a ``maite.protocols.object_detection.InputBatchType`` as input and returns a ``maite.protocols.object_detection.TargetBatchType``.
- A ``metadata`` attribute that’s a typed dictionary with at least an ``id`` field.

A generic skeleton combining these two requirements is sketched below.
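Distilling the two wrappers above into a reusable starting point, the sketch below shows one way such a skeleton might look. The ``predict_fn`` argument is an assumed callable (not something provided by MAITE or this guide) that you would implement for your own model; it is expected to take an ``(H, W, C)`` NumPy array and return boxes, labels, and scores for that image.

.. code:: ipython3

    from typing import Callable, Mapping, Tuple


    class WrappedDetector:
        """Sketch of a generic MAITE object detection wrapper."""

        def __init__(
            self,
            predict_fn: Callable[[np.ndarray], Tuple[ArrayLike, ArrayLike, ArrayLike]],
            id: str,
            index2label: Mapping[int, str],
        ):
            # `predict_fn` is an assumed per-image prediction callable that returns
            # (boxes, labels, scores); adapt it to your native model's API.
            self.predict_fn = predict_fn
            self.metadata: ModelMetadata = {"id": id, "index2label": dict(index2label)}

        def __call__(self, batch: od.InputBatchType) -> Sequence[ObjectDetectionTargetImpl]:
            outputs = []
            for img in batch:
                hwc = np.asarray(img).transpose((1, 2, 0))  # MAITE inputs are (C, H, W)
                boxes, labels, scores = self.predict_fn(hwc)
                outputs.append(
                    ObjectDetectionTargetImpl(
                        boxes=np.asarray(boxes),
                        labels=np.asarray(labels),
                        scores=np.asarray(scores),
                    )
                )
            return outputs

As with the wrappers above, the only MAITE-specific requirements are the ``__call__`` signature and the ``metadata`` attribute; everything else is ordinary glue code around your native model.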