Copyright 2025, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
SPDX-License-Identifier: MIT

Wrap an Object Detection Dataset

In this how-to, we will show you how to wrap an object detection dataset as a maite.protocols.object_detection.Dataset. The general steps for wrapping a dataset are:

  • Understand the source (native) dataset

  • Create a wrapper that makes the source dataset conform to the MAITE dataset protocol (interface)

  • Verify that the wrapped dataset works correctly

Note:

  • Implementors do not need to have their class inherit from a base class like maite.protocols.object_detection.Dataset (although they can) since MAITE uses Python protocols for defining interfaces. This note applies to other MAITE protocols as well, e.g., Model and Metric. See the Primer on Python Typing for more background on protocols and structural subtyping.

  • There are multiple ways of implementing a Dataset class. Implementors can start from scratch or leverage dataset utilities provided by other libraries such as Torchvision or Hugging Face.
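To illustrate the first note, here is a minimal sketch of structural subtyping using only the standard library. SizedIndexable, Letters, and first_item are hypothetical names made up for this example. SizedIndexable is a toy protocol, not MAITE's od.Dataset, which also specifies return types and a metadata attribute. Letters never inherits from the protocol, yet it satisfies it because it has the required methods.

```python
from typing import Protocol, runtime_checkable


# Hypothetical, simplified protocol for illustration only
@runtime_checkable
class SizedIndexable(Protocol):
    def __len__(self) -> int: ...

    def __getitem__(self, index: int) -> str: ...


# Note: no inheritance from SizedIndexable
class Letters:
    def __init__(self, text: str):
        self._text = text

    def __len__(self) -> int:
        return len(self._text)

    def __getitem__(self, index: int) -> str:
        return self._text[index]


def first_item(data: SizedIndexable) -> str:
    # A static type checker such as Pyright accepts a Letters argument here
    # because Letters has the required methods, not because of a base class
    return data[0]


letters = Letters("maite")
print(first_item(letters))  # m
# runtime_checkable enables isinstance(), which only checks that the methods exist
print(isinstance(letters, SizedIndexable))  # True
```

The isinstance check is a convenience: it does not verify signatures or return types. That deeper verification is what a static type checker provides.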

This how-to guide uses Torchvision’s torchvision.datasets.CocoDetection class to load images and COCO-formatted annotations from a directory.

1 Load a subset of the COCO dataset with Torchvision

First we load the necessary packages for running this notebook and display their versions.

import json
import matplotlib.pyplot as plt
import pprint
import requests
import torch

from dataclasses import dataclass
from maite.protocols import DatasetMetadata, DatumMetadata, object_detection as od
from pathlib import Path
from torchvision.datasets import CocoDetection
from torchvision.ops.boxes import box_convert
from torchvision.transforms.functional import pil_to_tensor
from typing import Any

%matplotlib inline

%load_ext watermark
%watermark -iv -v
Python implementation: CPython
Python version       : 3.9.21
IPython version      : 8.18.1

maite      : 0.7.3
torchvision: 0.21.0
json       : 2.0.9
requests   : 2.32.3
torch      : 2.6.0
matplotlib : 3.9.4

In order to make this how-to guide faster to run and not require a large download, we’ll work with a subset of the Common Objects in Context (COCO) dataset. We’ve modified the annotations JSON file from the validation split of the COCO 2017 Object Detection Task to contain only the first 4 images (and will dynamically download only those images using the “coco_url”).

Note that the COCO annotations are licensed under a Creative Commons Attribution 4.0 License (see COCO terms of use).
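This guide doesn't show how the four-image annotation file was produced. One plausible approach (an assumption, not necessarily what was done here) is to keep the first N entries of "images" and only the "annotations" whose image_id refers to them, copying every other top-level key unchanged. make_coco_subset is a hypothetical helper, and the toy dictionary below only mimics the COCO layout.

```python
from typing import Any


def make_coco_subset(coco_json: dict[str, Any], num_images: int) -> dict[str, Any]:
    """Keep the first `num_images` images and the annotations that refer to them.

    Hypothetical helper for illustration; other top-level keys
    (e.g., "info", "licenses", "categories") are copied unchanged.
    """
    images = coco_json["images"][:num_images]
    keep_ids = {img["id"] for img in images}
    annotations = [a for a in coco_json["annotations"] if a["image_id"] in keep_ids]
    # Build a new dict so the input is not modified
    return {**coco_json, "images": images, "annotations": annotations}


# Toy data in the COCO layout (not real COCO annotations)
full = {
    "images": [{"id": 10, "file_name": "a.jpg"}, {"id": 20, "file_name": "b.jpg"}],
    "annotations": [
        {"id": 1, "image_id": 10, "category_id": 1, "bbox": [0, 0, 5, 5]},
        {"id": 2, "image_id": 20, "category_id": 3, "bbox": [1, 1, 2, 2]},
    ],
    "categories": [{"id": 1, "name": "person"}, {"id": 3, "name": "car"}],
}
subset = make_coco_subset(full, num_images=1)
print([img["id"] for img in subset["images"]])  # [10]
print([a["id"] for a in subset["annotations"]])  # [1]
```

The result could then be written out with json.dump to produce a file like instances_val2017_first4.json.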

We use the following function and code to download the images:

def download_images(coco_json_subset: dict[str, Any], root: Path):
    """Download a subset of COCO images.

    Parameters
    ----------
    coco_json_subset : dict[str, Any]
        Parsed contents of the COCO instances_val2017_first4.json annotation file.
    root : Path
        Directory in which to save the images.
    """
    root.mkdir(parents=True, exist_ok=True)

    for image in coco_json_subset["images"]:
        url = image["coco_url"]
        filename = root / image["file_name"]
        if filename.exists():
            print(f"skipping {url}")
        else:
            print(f"saving {url} to {filename} ... ", end="")
            r = requests.get(url, timeout=30)
            r.raise_for_status()  # avoid writing an HTTP error page to disk
            with open(filename, "wb") as f:
                f.write(r.content)
            print("done")
COCO_ROOT = Path("../sample_data/coco/coco_val2017_subset")
ann_subset_file = COCO_ROOT / "instances_val2017_first4.json"
with open(ann_subset_file, "r") as f:
    coco_subset_json = json.load(f)
download_images(coco_subset_json, root=COCO_ROOT)
saving http://images.cocodataset.org/val2017/000000397133.jpg to ../sample_data/coco/coco_val2017_subset/000000397133.jpg ... done
saving http://images.cocodataset.org/val2017/000000037777.jpg to ../sample_data/coco/coco_val2017_subset/000000037777.jpg ... done
saving http://images.cocodataset.org/val2017/000000252219.jpg to ../sample_data/coco/coco_val2017_subset/000000252219.jpg ... done
saving http://images.cocodataset.org/val2017/000000087038.jpg to ../sample_data/coco/coco_val2017_subset/000000087038.jpg ... done

Next we use Torchvision’s CocoDetection dataset class to load our COCO subset located under maite/examples/sample_data/coco/coco_val2017_subset. The two required arguments of CocoDetection, which we pass as keyword arguments, are:

  • root: the location of the folder containing the images

  • annFile: the location of the annotation JSON file

Please see Torchvision’s documentation for more information on how to use the CocoDetection class.
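As a rough mental model, CocoDetection sorts the image ids found in the annotation file and, for a given index, returns the image with that id together with the list of annotations whose image_id matches. The sketch below is a simplified, hypothetical stand-in (SimpleCocoIndex is our own name): it returns the image id instead of loading the image file and does not use the pycocotools COCO API that Torchvision relies on. The annotations are toy values, but the two image ids come from our subset, and the sorting explains why index 0 will turn out to be image 37777 even though 397133 is listed first in the JSON file.

```python
from collections import defaultdict
from typing import Any


class SimpleCocoIndex:
    """Simplified, hypothetical stand-in for CocoDetection's indexing logic."""

    def __init__(self, coco_json: dict[str, Any]):
        # Image ids are sorted, so position 0 holds the smallest id,
        # not necessarily the first entry of the JSON "images" list
        self.ids = sorted(img["id"] for img in coco_json["images"])
        self._anns: dict[int, list[dict[str, Any]]] = defaultdict(list)
        for ann in coco_json["annotations"]:
            self._anns[ann["image_id"]].append(ann)

    def __len__(self) -> int:
        return len(self.ids)

    def __getitem__(self, index: int) -> tuple[int, list[dict[str, Any]]]:
        image_id = self.ids[index]
        # Images without annotations map to an empty list
        return image_id, list(self._anns.get(image_id, []))


# Toy annotations for two of the subset's image ids
coco_json = {
    "images": [{"id": 397133}, {"id": 37777}],
    "annotations": [
        {"id": 1, "image_id": 37777, "category_id": 1},
        {"id": 2, "image_id": 37777, "category_id": 62},
    ],
}
index = SimpleCocoIndex(coco_json)
print(len(index))  # 2
print(index[0][0], len(index[0][1]))  # 37777 2
print(index[1])  # (397133, [])
```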

tv_dataset = CocoDetection(
    root=str(COCO_ROOT),
    annFile=str(ann_subset_file),
)

print(f"\n{len(tv_dataset) = }")
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!

len(tv_dataset) = 4

2 Examine the source dataset

One datum from tv_dataset includes an image and its corresponding annotations. Let’s get one datum and inspect it further.

img, annotations = tv_dataset[0]

# `img` is a PIL Image
plt.imshow(img)
plt.axis("off")  # Optional: Hide the axes
plt.show()
print(f"{type(img) = }")

# `annotations` is a list of dictionaries (corresponding to 14 objects in this case)
print(f"{len(annotations) = }")

# Each annotation dictionary contains the object's bounding box (bbox) plus other info
print(f"{annotations[0].keys() = }")

# Note that the COCO bounding box format is [x_min, y_min, width, height]
print(f"{annotations[0]['bbox'] = }")

# Class/category labels and ids are available via the `cats` attribute on the dataset itself
print(f"{tv_dataset.coco.cats[1] = }")
[Image: tv_dataset[0] displayed with plt.imshow]
type(img) = <class 'PIL.Image.Image'>
len(annotations) = 14
annotations[0].keys() = dict_keys(['segmentation', 'area', 'iscrowd', 'image_id', 'bbox', 'category_id', 'id'])
annotations[0]['bbox'] = [102.49, 118.47, 7.9, 17.31]
tv_dataset.coco.cats[1] = {'supercategory': 'person', 'id': 1, 'name': 'person'}

3 Create a MAITE wrapper for the source dataset

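So far we’ve used Torchvision’s CocoDetection class to load and explore the dataset. It doesn’t comply with maite.protocols.object_detection.Dataset: each datum is a PIL image paired with a list of annotation dictionaries (with boxes in COCO’s [x_min, y_min, width, height] format), and the dataset has no metadata attribute. However, it can be used as a starting point to construct a MAITE-compliant object detection dataset.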

A class implementing maite.protocols.object_detection.Dataset needs the following two methods:

  • __getitem__(index: int): returns a tuple containing a datum’s input image, ground-truth detections, and metadata.

  • __len__(): returns the number of data elements in the dataset.

and an attribute:

  • metadata: a typed dictionary containing an id string and an optional (but recommended) map from class indexes to labels.

The datum returned by __getitem__ is of type tuple[InputType, ObjectDetectionTarget, DatumMetadata].

  • MAITE’s InputType is an ArrayLike (e.g., numpy.ndarray or torch.Tensor) with shape (C, H, W) representing the channel, height, and width of an image. Since tv_dataset has images in the PIL image format, we need to convert them to a type compatible with ArrayLike, which we’ll do using the pil_to_tensor function.

  • MAITE’s ObjectDetectionTarget protocol can be implemented by defining a dataclass, CocoDetectionTarget, which encodes the ground-truth labels of the object detection task, namely:

    • boxes: a shape-(N_DETECTIONS, 4) ArrayLike, where N_DETECTIONS is the number of detections (objects) in the current single image, and each row is a bounding box in x0, y0, x1, y1 format

    • labels: a shape-(N_DETECTIONS,) ArrayLike containing the integer label associated with each detection

    • scores: a shape-(N_DETECTIONS,) ArrayLike containing the score (confidence) associated with each detection. For a dataset’s ground truth (as opposed to model predictions), scores are always 1.

  • MAITE’s DatumMetadata is a TypedDict containing metadata associated with a single datum. An id of type int or str is required.
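The main geometric change the wrapper makes is the box format: COCO stores [x_min, y_min, width, height], while the target above expects [x0, y0, x1, y1]. The arithmetic that box_convert(..., in_fmt="xywh", out_fmt="xyxy") performs can be sketched in plain Python, using the first bounding box printed in section 2. xywh_to_xyxy is a hypothetical helper for illustration only.

```python
def xywh_to_xyxy(bbox: list[float]) -> list[float]:
    """Convert a COCO [x_min, y_min, width, height] box to [x0, y0, x1, y1]."""
    x_min, y_min, width, height = bbox
    # The top-left corner is unchanged; the bottom-right corner adds the size
    return [x_min, y_min, x_min + width, y_min + height]


# First annotation of tv_dataset[0], as printed in section 2
coco_bbox = [102.49, 118.47, 7.9, 17.31]
print([round(v, 2) for v in xywh_to_xyxy(coco_bbox)])  # [102.49, 118.47, 110.39, 135.78]
```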

Putting everything together, below is an implementation of the maite.protocols.object_detection.Dataset protocol called CocoDetectionDataset.

@dataclass
class CocoDetectionTarget:
    boxes: torch.Tensor
    labels: torch.Tensor
    scores: torch.Tensor

class CocoDetectionDataset:
    def __init__(self, dataset: CocoDetection, id: str):
        self._dataset = dataset

        # Get mapping from COCO category to name
        index2label = {k: v["name"] for k, v in dataset.coco.cats.items()}

        # Add dataset-level metadata attribute, with required id and optional index2label mapping
        self.metadata: DatasetMetadata = {
            "id": id,
            "index2label": index2label,
        }

    def __len__(self) -> int:
        return len(self._dataset)

    def __getitem__(
        self, index: int
    ) -> tuple[torch.Tensor, CocoDetectionTarget, DatumMetadata]:

        # Get original data item
        img_pil, annotations = self._dataset[index]

        # Format input
        img_pt = pil_to_tensor(img_pil)

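        # Format ground truth: convert COCO [x, y, w, h] boxes to [x0, y0, x1, y1]
        num_boxes = len(annotations)
        bboxes = torch.as_tensor(
            [ann["bbox"] for ann in annotations], dtype=torch.float32
        ).reshape(num_boxes, 4)  # keeps shape (0, 4) for images without objects
        boxes = box_convert(bboxes, in_fmt="xywh", out_fmt="xyxy")

        # Explicit integer dtype so labels stay integers even when empty
        labels = torch.as_tensor(
            [ann["category_id"] for ann in annotations], dtype=torch.int64
        )
        scores = torch.ones(num_boxes)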

        # Format metadata
        datum_metadata: DatumMetadata = {
            "id": self._dataset.ids[index],
        }

        return img_pt, CocoDetectionTarget(boxes, labels, scores), datum_metadata
coco_dataset: od.Dataset = CocoDetectionDataset(tv_dataset, "COCO Detection Subset")

We have now created a dataset class that conforms to the maite.protocols.object_detection.Dataset protocol, and wrapped the native dataset with it.

Here the coco_dataset variable has od.Dataset as the type hint. If your development environment has a static type checker like Pyright enabled (see the Enable Static Type Checking guide), then the type checker will verify that our wrapped dataset conforms to the protocol and indicate a problem if not.

4 Examine the MAITE-wrapped dataset

Now let’s inspect the MAITE-compliant object detection dataset to verify that it’s behaving as expected.

First we verify that the length is 4:

print(f"{len(coco_dataset) = }")
len(coco_dataset) = 4

We display the dataset-level metadata:

pprint.pp(coco_dataset.metadata)
{'id': 'COCO Detection Subset',
 'index2label': {1: 'person',
                 2: 'bicycle',
                 3: 'car',
                 4: 'motorcycle',
                 5: 'airplane',
                 6: 'bus',
                 7: 'train',
                 8: 'truck',
                 9: 'boat',
                 10: 'traffic light',
                 11: 'fire hydrant',
                 13: 'stop sign',
                 14: 'parking meter',
                 15: 'bench',
                 16: 'bird',
                 17: 'cat',
                 18: 'dog',
                 19: 'horse',
                 20: 'sheep',
                 21: 'cow',
                 22: 'elephant',
                 23: 'bear',
                 24: 'zebra',
                 25: 'giraffe',
                 27: 'backpack',
                 28: 'umbrella',
                 31: 'handbag',
                 32: 'tie',
                 33: 'suitcase',
                 34: 'frisbee',
                 35: 'skis',
                 36: 'snowboard',
                 37: 'sports ball',
                 38: 'kite',
                 39: 'baseball bat',
                 40: 'baseball glove',
                 41: 'skateboard',
                 42: 'surfboard',
                 43: 'tennis racket',
                 44: 'bottle',
                 46: 'wine glass',
                 47: 'cup',
                 48: 'fork',
                 49: 'knife',
                 50: 'spoon',
                 51: 'bowl',
                 52: 'banana',
                 53: 'apple',
                 54: 'sandwich',
                 55: 'orange',
                 56: 'broccoli',
                 57: 'carrot',
                 58: 'hot dog',
                 59: 'pizza',
                 60: 'donut',
                 61: 'cake',
                 62: 'chair',
                 63: 'couch',
                 64: 'potted plant',
                 65: 'bed',
                 67: 'dining table',
                 70: 'toilet',
                 72: 'tv',
                 73: 'laptop',
                 74: 'mouse',
                 75: 'remote',
                 76: 'keyboard',
                 77: 'cell phone',
                 78: 'microwave',
                 79: 'oven',
                 80: 'toaster',
                 81: 'sink',
                 82: 'refrigerator',
                 84: 'book',
                 85: 'clock',
                 86: 'vase',
                 87: 'scissors',
                 88: 'teddy bear',
                 89: 'hair drier',
                 90: 'toothbrush'}}
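Note that COCO category ids are not contiguous: the 80 categories above use ids between 1 and 90, with gaps (for example, 12, 26, and 45 are absent). Our wrapper keeps the original COCO ids as labels. If a downstream model instead predicts contiguous indices 0 to K-1 (an assumption about your model, not a MAITE requirement), a remapping like the hypothetical make_contiguous helper below may be needed; the small dictionary is a toy excerpt of index2label.

```python
def make_contiguous(index2label: dict[int, str]) -> tuple[dict[int, int], dict[int, str]]:
    """Map sparse category ids to 0..K-1 (hypothetical helper).

    Returns the old-to-new id mapping and the remapped index2label.
    """
    sorted_ids = sorted(index2label)
    old_to_new = {old: new for new, old in enumerate(sorted_ids)}
    new_index2label = {old_to_new[old]: index2label[old] for old in sorted_ids}
    return old_to_new, new_index2label


sparse = {1: "person", 2: "bicycle", 3: "car", 13: "stop sign"}
old_to_new, dense = make_contiguous(sparse)
print(old_to_new)  # {1: 0, 2: 1, 3: 2, 13: 3}
print(dense)  # {0: 'person', 1: 'bicycle', 2: 'car', 3: 'stop sign'}
```

The same old_to_new mapping would then have to be applied to each datum's labels so that targets and index2label stay consistent.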

Finally we inspect a single datum:

# Get first datum
image, target, datum_metadata = coco_dataset[0]

# Bridge/convert ArrayLikes to PyTorch tensors
image = torch.as_tensor(image)
boxes = torch.as_tensor(target.boxes)
labels = torch.as_tensor(target.labels)
scores = torch.as_tensor(target.scores)

# Display the wrapped image; matplotlib expects (H, W, C), so permute from (C, H, W)
plt.imshow(image.permute(1, 2, 0).numpy())
plt.axis("off")  # Optional: Hide the axes
plt.show()

# Print shapes
print(f"{image.shape = }")  # image has height 230 and width 352
print(f"{boxes.shape = }")  # there are 14 bounding boxes
print(f"{labels.shape = }")
print(f"{scores.shape = }")

# Print datum-level metadata
print(f"{datum_metadata = }")  # this datum corresponds to image file 000000037777.jpg
[Image: the first datum of coco_dataset displayed with plt.imshow]
image.shape = torch.Size([3, 230, 352])
boxes.shape = torch.Size([14, 4])
labels.shape = torch.Size([14])
scores.shape = torch.Size([14])
datum_metadata = {'id': 37777}
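The shape checks above are done by eye for one datum. A hypothetical helper like check_detection_target (our own name, not part of MAITE) could automate them, for instance by calling it on target.boxes, target.labels, and target.scores for every index in range(len(coco_dataset)). It is tested here only with plain lists; it should also accept torch tensors, since they support len() and row iteration, but treat that as an assumption.

```python
from typing import Any


def check_detection_target(boxes: Any, labels: Any, scores: Any) -> None:
    """Raise ValueError if boxes/labels/scores are inconsistent (hypothetical helper).

    boxes is expected to have shape (N, 4) in x0, y0, x1, y1 format;
    labels and scores are expected to have shape (N,).
    """
    n = len(boxes)
    if len(labels) != n or len(scores) != n:
        raise ValueError(
            f"lengths differ: boxes={n}, labels={len(labels)}, scores={len(scores)}"
        )
    for i, box in enumerate(boxes):
        if len(box) != 4:
            raise ValueError(f"box {i} has {len(box)} values, expected 4")
        x0, y0, x1, y1 = box
        # A box still in COCO xywh format often fails this ordering check
        if x1 < x0 or y1 < y0:
            raise ValueError(f"box {i} is not in x0, y0, x1, y1 order: {box}")


check_detection_target([[0, 0, 10, 20]], [1], [1.0])  # passes silently
try:
    check_detection_target([[0, 0, 10, 20]], [1, 2], [1.0])
except ValueError as err:
    print(err)  # lengths differ: boxes=1, labels=2, scores=1
```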

5 Conclusion

In this how-to guide, we demonstrated how to wrap an object detection dataset so that it is MAITE-compliant. Key implementation steps included:

  • Defining __len__ and __getitem__ methods

  • Choosing a specific image InputType (e.g., numpy.ndarray or torch.Tensor)

  • Creating a target dataclass that conforms to the ObjectDetectionTarget protocol

  • Setting up DatasetMetadata and DatumMetadata

More details on the dataset protocol can be found in the maite.protocols.object_detection.Dataset API reference.