..
   Copyright 2025, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
   Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
   SPDX-License-Identifier: MIT

Wrap an Object Detection Dataset
================================

In this how-to, we will show you how to wrap an object detection dataset as a
`maite.protocols.object_detection.Dataset <../generated/maite.protocols.object_detection.Dataset.html>`__.

The general steps for wrapping a dataset are:

- Understand the source (native) dataset
- Create a wrapper that makes the source dataset conform to the MAITE dataset protocol (interface)
- Verify that the wrapped dataset works correctly

Note:

- Implementors do not need to have their class inherit from a base class like
  ``maite.protocols.object_detection.Dataset`` (although they can), since MAITE uses
  Python protocols for defining interfaces; a minimal sketch follows this list. This
  note applies to other MAITE protocols as well, e.g., ``Model`` and ``Metric``. See the
  `Primer on Python Typing <../explanation/type_hints_for_API_design.html>`__ for more
  background on protocols and structural subtyping.
- There are multiple ways of implementing a ``Dataset`` class. Implementors can start
  from scratch or leverage dataset utilities provided by other libraries such as
  Torchvision or Hugging Face. This how-to guide will use Torchvision's
  `torchvision.datasets.CocoDetection <https://pytorch.org/vision/stable/generated/torchvision.datasets.CocoDetection.html>`__
  class to load images and COCO-formatted annotations from a directory.
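To make the structural subtyping point concrete, here is a minimal sketch using a toy
protocol (not a MAITE one): a class satisfies a ``Protocol`` simply by implementing the
expected methods, with no inheritance involved.

.. code:: python

    from typing import Protocol, runtime_checkable

    @runtime_checkable
    class SupportsLen(Protocol):
        def __len__(self) -> int: ...

    class MyCollection:  # note: does not inherit from SupportsLen
        def __init__(self, items: list):
            self._items = items

        def __len__(self) -> int:
            return len(self._items)

    # A static type checker accepts this assignment because MyCollection
    # structurally matches SupportsLen; isinstance() also works at runtime
    # here because the protocol is decorated with @runtime_checkable.
    c: SupportsLen = MyCollection([1, 2, 3])
    print(isinstance(c, SupportsLen))  # True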
1 Load a subset of the COCO dataset with Torchvision
----------------------------------------------------

First we load the necessary packages for running this notebook and display their versions.

.. code:: ipython3

    import json
    import matplotlib.pyplot as plt
    import pprint
    import requests
    import torch

    from dataclasses import dataclass
    from maite.protocols import DatasetMetadata, DatumMetadata, object_detection as od
    from pathlib import Path
    from torchvision.datasets import CocoDetection
    from torchvision.ops.boxes import box_convert
    from torchvision.transforms.functional import pil_to_tensor
    from typing import Any

    %matplotlib inline
    %load_ext watermark
    %watermark -iv -v

.. parsed-literal::

    Python implementation: CPython
    Python version       : 3.9.21
    IPython version      : 8.18.1

    maite      : 0.7.3
    torchvision: 0.21.0
    json       : 2.0.9
    requests   : 2.32.3
    torch      : 2.6.0
    matplotlib : 3.9.4

To make this how-to guide faster to run and avoid a large download, we'll work with a
subset of the `Common Objects in Context (COCO) <https://cocodataset.org/>`__ dataset.
We've modified the annotations JSON file from the validation split of the COCO 2017
Object Detection Task to contain only the first 4 images (and will dynamically download
only those images using the "coco_url"). Note that the COCO annotations are licensed
under a `Creative Commons Attribution 4.0 License <https://creativecommons.org/licenses/by/4.0/>`__
(see COCO `terms of use <https://cocodataset.org/#termsofuse>`__).

We use the following function and code to download the images:

.. code:: ipython3

    def download_images(coco_json_subset: dict[str, Any], root: Path):
        """Download a subset of COCO images.

        Parameters
        ----------
        coco_json_subset : dict[str, Any]
            Parsed contents of the COCO val2017_first4 annotations JSON file.
        root : Path
            Location to store the downloaded COCO images.
        """
        root.mkdir(parents=True, exist_ok=True)
        for image in coco_json_subset["images"]:
            url = image["coco_url"]
            filename = Path(root) / image["file_name"]
            if filename.exists():
                print(f"skipping {url}")
            else:
                print(f"saving {url} to {filename} ... ", end="")
                r = requests.get(url)
                with open(filename, "wb") as f:
                    f.write(r.content)
                print("done")

.. code:: ipython3

    COCO_ROOT = Path("../sample_data/coco/coco_val2017_subset")

    ann_subset_file = COCO_ROOT / "instances_val2017_first4.json"
    with open(ann_subset_file, "r") as f:
        coco_subset_json = json.load(f)

    download_images(coco_subset_json, root=COCO_ROOT)

.. parsed-literal::

    saving http://images.cocodataset.org/val2017/000000397133.jpg to ../sample_data/coco/coco_val2017_subset/000000397133.jpg ... done
    saving http://images.cocodataset.org/val2017/000000037777.jpg to ../sample_data/coco/coco_val2017_subset/000000037777.jpg ... done
    saving http://images.cocodataset.org/val2017/000000252219.jpg to ../sample_data/coco/coco_val2017_subset/000000252219.jpg ... done
    saving http://images.cocodataset.org/val2017/000000087038.jpg to ../sample_data/coco/coco_val2017_subset/000000087038.jpg ... done

Next we use Torchvision's ``CocoDetection`` dataset class to load our COCO subset
located under ``maite/examples/sample_data/coco/coco_val2017_subset``. The keyword
arguments of ``CocoDetection`` are:

- ``root``: the location of the folder containing the images
- ``annFile``: the location of the annotation JSON file

Please see Torchvision's `documentation <https://pytorch.org/vision/stable/generated/torchvision.datasets.CocoDetection.html>`__
for more information on how to use the ``CocoDetection`` class.

.. code:: ipython3

    tv_dataset = CocoDetection(
        root=str(COCO_ROOT),
        annFile=str(ann_subset_file),
    )
    print(f"\n{len(tv_dataset) = }")

.. parsed-literal::

    loading annotations into memory...
    Done (t=0.00s)
    creating index...
    index created!

    len(tv_dataset) = 4

2 Examine the source dataset
----------------------------

One datum from ``tv_dataset`` includes an image and its corresponding annotations. Let's
get one datum and inspect it further.

.. code:: ipython3

    img, annotations = tv_dataset[0]

    # `img` is a PIL Image
    plt.imshow(img)
    plt.axis("off")  # Optional: Hide the axes
    plt.show()
    print(f"{type(img) = }")

    # `annotations` is a list of dictionaries (corresponding to 14 objects in this case)
    print(f"{len(annotations) = }")

    # Each annotation dictionary contains the object's bounding box (bbox) plus other info
    print(f"{annotations[0].keys() = }")

    # Note that the COCO bounding box format is [x_min, y_min, width, height]
    print(f"{annotations[0]['bbox'] = }")

    # Class/category labels and ids are available via the `cats` attribute on the dataset itself
    print(f"{tv_dataset.coco.cats[1] = }")

.. image:: wrap_object_detection_dataset_files/wrap_object_detection_dataset_10_0.png

.. parsed-literal::

    type(img) = <class 'PIL.Image.Image'>
    len(annotations) = 14
    annotations[0].keys() = dict_keys(['segmentation', 'area', 'iscrowd', 'image_id', 'bbox', 'category_id', 'id'])
    annotations[0]['bbox'] = [102.49, 118.47, 7.9, 17.31]
    tv_dataset.coco.cats[1] = {'supercategory': 'person', 'id': 1, 'name': 'person'}
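The wrapper we build next will need to convert COCO's ``[x_min, y_min, width, height]``
boxes to corner format. As a minimal sketch (reusing the imports from above and the
first annotation's box shown earlier), the conversion looks like this:

.. code:: python

    # Convert a COCO xywh box to xyxy corner format:
    # x1 = x_min + width, y1 = y_min + height
    xywh = torch.as_tensor([102.49, 118.47, 7.9, 17.31])
    xyxy = box_convert(xywh, in_fmt="xywh", out_fmt="xyxy")
    print(xyxy)  # tensor([102.4900, 118.4700, 110.3900, 135.7800])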
3 Create a MAITE wrapper for the source dataset
-----------------------------------------------

So far we've used Torchvision's ``CocoDetection`` class to load and explore the dataset.
It doesn't comply with ``maite.protocols.object_detection.Dataset``; however, it can be
used as a starting point to construct a MAITE-compliant object detection dataset.

A class implementing ``maite.protocols.object_detection.Dataset`` needs the following two methods:

- ``__getitem__(index: int)``: returns a tuple containing a datum's input image, ground-truth detections, and metadata.
- ``__len__()``: returns the number of data elements in the dataset.

and an attribute:

- ``metadata``: a typed dictionary containing an ``id`` string and an optional (but recommended) map from class indexes to labels.

The datum returned by ``__getitem__`` is of type
``tuple[InputType, ObjectDetectionTarget, DatumMetadata]``.

- MAITE's ``InputType`` is an ``ArrayLike`` (e.g., ``numpy.ndarray`` or ``torch.Tensor``)
  with shape ``(C, H, W)`` representing the channel, height, and width of an image. Since
  ``tv_dataset`` has images in the PIL image format, we need to convert them to a type
  compatible with ``ArrayLike``, which we'll do using the ``pil_to_tensor`` function.
- MAITE's ``ObjectDetectionTarget`` protocol can be implemented by defining a
  ``dataclass``, ``CocoDetectionTarget``, which encodes the ground-truth labels of the
  object detection task, namely:

  - ``boxes``: a shape-(N_DETECTIONS, 4) ``ArrayLike``, where N_DETECTIONS is the number
    of detections (objects) in the current single image, and each row is a bounding box
    in ``x0, y0, x1, y1`` format
  - ``labels``: a shape-(N_DETECTIONS,) ``ArrayLike`` containing the integer label
    associated with each detection
  - ``scores``: a shape-(N_DETECTIONS,) ``ArrayLike`` containing the score (confidence)
    associated with each detection. For a dataset's ground truth (as opposed to model
    predictions), ``scores`` are always 1.

- MAITE's ``DatumMetadata`` is a ``TypedDict`` containing metadata associated with a
  single datum. An ``id`` of type ``int`` or ``str`` is required.

Putting everything together, below is an implementation of the
``maite.protocols.object_detection.Dataset`` protocol called ``CocoDetectionDataset``.

.. code:: ipython3

    @dataclass
    class CocoDetectionTarget:
        boxes: torch.Tensor
        labels: torch.Tensor
        scores: torch.Tensor


    class CocoDetectionDataset:
        def __init__(self, dataset: CocoDetection, id: str):
            self._dataset = dataset

            # Get mapping from COCO category id to name
            index2label = {k: v["name"] for k, v in dataset.coco.cats.items()}

            # Add dataset-level metadata attribute, with required id and
            # optional index2label mapping
            self.metadata: DatasetMetadata = {
                "id": id,
                "index2label": index2label,
            }

        def __len__(self) -> int:
            return len(self._dataset)

        def __getitem__(
            self, index: int
        ) -> tuple[torch.Tensor, CocoDetectionTarget, DatumMetadata]:

            # Get original data item
            img_pil, annotations = self._dataset[index]

            # Format input
            img_pt = pil_to_tensor(img_pil)

            # Format ground truth
            num_boxes = len(annotations)
            boxes = torch.zeros(num_boxes, 4)
            for i, ann in enumerate(annotations):
                bbox = torch.as_tensor(ann["bbox"])
                boxes[i, :] = box_convert(bbox, in_fmt="xywh", out_fmt="xyxy")

            labels = torch.as_tensor([ann["category_id"] for ann in annotations])
            scores = torch.ones(num_boxes)

            # Format metadata
            datum_metadata: DatumMetadata = {
                "id": self._dataset.ids[index],
            }

            return img_pt, CocoDetectionTarget(boxes, labels, scores), datum_metadata

.. code:: ipython3

    coco_dataset: od.Dataset = CocoDetectionDataset(tv_dataset, "COCO Detection Subset")

We have now created a dataset class that conforms to the
``maite.protocols.object_detection.Dataset`` protocol and wrapped the native dataset
with it. Here the ``coco_dataset`` variable has ``od.Dataset`` as its type hint. If your
development environment has a static type checker like Pyright enabled (see the
`Enable Static Type Checking <./static_typing.html>`__ guide), the type checker will
verify that our wrapped dataset conforms to the protocol and flag a problem if not.
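The same static-checking approach works for the target class. As a small, illustrative
sketch (the values below are dummy placeholders, not real annotations), a type checker
will also confirm that ``CocoDetectionTarget`` structurally conforms to the
``ObjectDetectionTarget`` protocol:

.. code:: python

    # A static type checker (e.g., Pyright) will flag this assignment if
    # CocoDetectionTarget does not satisfy the ObjectDetectionTarget protocol.
    dummy_target: od.ObjectDetectionTarget = CocoDetectionTarget(
        boxes=torch.zeros(1, 4),
        labels=torch.zeros(1, dtype=torch.int64),
        scores=torch.ones(1),
    )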
4 Examine the MAITE-wrapped dataset
-----------------------------------

Now let's inspect the MAITE-compliant object detection dataset to verify that it's
behaving as expected.

First we verify that the length is 4:

.. code:: ipython3

    print(f"{len(coco_dataset) = }")

.. parsed-literal::

    len(coco_dataset) = 4

We display the dataset-level metadata:

.. code:: ipython3

    pprint.pp(coco_dataset.metadata)

.. parsed-literal::

    {'id': 'COCO Detection Subset',
     'index2label': {1: 'person',
                     2: 'bicycle',
                     3: 'car',
                     4: 'motorcycle',
                     5: 'airplane',
                     6: 'bus',
                     7: 'train',
                     8: 'truck',
                     9: 'boat',
                     10: 'traffic light',
                     11: 'fire hydrant',
                     13: 'stop sign',
                     14: 'parking meter',
                     15: 'bench',
                     16: 'bird',
                     17: 'cat',
                     18: 'dog',
                     19: 'horse',
                     20: 'sheep',
                     21: 'cow',
                     22: 'elephant',
                     23: 'bear',
                     24: 'zebra',
                     25: 'giraffe',
                     27: 'backpack',
                     28: 'umbrella',
                     31: 'handbag',
                     32: 'tie',
                     33: 'suitcase',
                     34: 'frisbee',
                     35: 'skis',
                     36: 'snowboard',
                     37: 'sports ball',
                     38: 'kite',
                     39: 'baseball bat',
                     40: 'baseball glove',
                     41: 'skateboard',
                     42: 'surfboard',
                     43: 'tennis racket',
                     44: 'bottle',
                     46: 'wine glass',
                     47: 'cup',
                     48: 'fork',
                     49: 'knife',
                     50: 'spoon',
                     51: 'bowl',
                     52: 'banana',
                     53: 'apple',
                     54: 'sandwich',
                     55: 'orange',
                     56: 'broccoli',
                     57: 'carrot',
                     58: 'hot dog',
                     59: 'pizza',
                     60: 'donut',
                     61: 'cake',
                     62: 'chair',
                     63: 'couch',
                     64: 'potted plant',
                     65: 'bed',
                     67: 'dining table',
                     70: 'toilet',
                     72: 'tv',
                     73: 'laptop',
                     74: 'mouse',
                     75: 'remote',
                     76: 'keyboard',
                     77: 'cell phone',
                     78: 'microwave',
                     79: 'oven',
                     80: 'toaster',
                     81: 'sink',
                     82: 'refrigerator',
                     84: 'book',
                     85: 'clock',
                     86: 'vase',
                     87: 'scissors',
                     88: 'teddy bear',
                     89: 'hair drier',
                     90: 'toothbrush'}}

Finally we inspect a single datum:

.. code:: ipython3

    # Get first datum
    image, target, datum_metadata = coco_dataset[0]

    # Bridge/convert ArrayLike's to PyTorch tensors
    image = torch.as_tensor(image)
    boxes = torch.as_tensor(target.boxes)
    labels = torch.as_tensor(target.labels)
    scores = torch.as_tensor(target.scores)

    # Display image (permuting from CHW to HWC for matplotlib)
    plt.imshow(image.permute(1, 2, 0))
    plt.axis("off")  # Optional: Hide the axes
    plt.show()

    # Print shapes
    print(f"{image.shape = }")  # image has height 230 and width 352
    print(f"{boxes.shape = }")  # there are 14 bounding boxes
    print(f"{labels.shape = }")
    print(f"{scores.shape = }")

    # Print datum-level metadata
    print(f"{datum_metadata = }")  # this datum corresponds to image file 000000037777.jpg

.. image:: wrap_object_detection_dataset_files/wrap_object_detection_dataset_20_0.png

.. parsed-literal::

    image.shape = torch.Size([3, 230, 352])
    boxes.shape = torch.Size([14, 4])
    labels.shape = torch.Size([14])
    scores.shape = torch.Size([14])
    datum_metadata = {'id': 37777}

5 Conclusion
------------

In this how-to guide, we demonstrated how to wrap an object detection dataset to be
MAITE compliant. Key implementation steps included:

- Defining ``__len__`` and ``__getitem__`` methods
- Choosing a specific image ``InputType`` (e.g., ``numpy.ndarray`` or ``torch.Tensor``)
- Creating a target dataclass that conforms to the ``ObjectDetectionTarget`` protocol
- Setting up ``DatasetMetadata`` and ``DatumMetadata``

More dataset protocol details can be found here:
`maite.protocols.object_detection.Dataset <../generated/maite.protocols.object_detection.Dataset.html>`__.
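As a final sanity check (echoing the "verify" step from the introduction), one can loop
over the wrapped dataset and assert basic shape and type expectations. This is a minimal
sketch, not part of MAITE's API; the ``check_dataset`` helper below is hypothetical:

.. code:: python

    def check_dataset(dataset: od.Dataset) -> None:
        """Iterate over a wrapped dataset and verify basic datum expectations."""
        for i in range(len(dataset)):
            image, datum_target, datum_metadata = dataset[i]

            # Bridge ArrayLike's to tensors before checking shapes
            image = torch.as_tensor(image)
            boxes = torch.as_tensor(datum_target.boxes)
            labels = torch.as_tensor(datum_target.labels)
            scores = torch.as_tensor(datum_target.scores)

            assert image.ndim == 3, "expected a (C, H, W) image"
            assert boxes.ndim == 2 and boxes.shape[1] == 4, "expected (N_DETECTIONS, 4) boxes"
            assert labels.shape[0] == boxes.shape[0] == scores.shape[0]
            assert "id" in datum_metadata, "datum metadata requires an id"

    check_dataset(coco_dataset)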