Copyright 2025, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
SPDX-License-Identifier: MIT
Wrap an Object Detection Dataset#
In this how-to, we will show you how to wrap an object detection dataset as a maite.protocols.object_detection.Dataset. The general steps for wrapping a dataset are:
Understand the source (native) dataset
Create a wrapper that makes the source dataset conform to the MAITE dataset protocol (interface)
Verify that the wrapped dataset works correctly
Note:
Implementors do not need to have their class inherit from a base class like maite.protocols.object_detection.Dataset (although they can) since MAITE uses Python protocols for defining interfaces. This note applies to other MAITE protocols as well, e.g., Model and Metric. See the Primer on Python Typing for more background on protocols and structural subtyping.
There are multiple ways of implementing a Dataset class. Implementors can start from scratch or leverage dataset utilities provided by other libraries such as Torchvision or Hugging Face.
This how-to guide will use Torchvision's torchvision.datasets.CocoDetection class to load images and COCO-formatted annotations from a directory.
1 Load a subset of the COCO dataset with Torchvision#
First we load the necessary packages for running this notebook and display their versions.
import json
import matplotlib.pyplot as plt
import pprint
import requests
import torch
from dataclasses import dataclass
from maite.protocols import DatasetMetadata, DatumMetadata, object_detection as od
from pathlib import Path
from torchvision.datasets import CocoDetection
from torchvision.ops.boxes import box_convert
from torchvision.transforms.functional import pil_to_tensor
from typing import Any
%matplotlib inline
%load_ext watermark
%watermark -iv -v
Python implementation: CPython
Python version : 3.9.21
IPython version : 8.18.1
maite : 0.7.3
torchvision: 0.21.0
json : 2.0.9
requests : 2.32.3
torch : 2.6.0
matplotlib : 3.9.4
In order to make this how-to guide faster to run and not require a large download, we’ll work with a subset of the Common Objects in Context (COCO) dataset. We’ve modified the annotations JSON file from the validation split of the COCO 2017 Object Detection Task to contain only the first 4 images (and will dynamically download only those images using the “coco_url”).
Note that the COCO annotations are licensed under a Creative Commons Attribution 4.0 License (see COCO terms of use).
We use the following function and code to download the images:
def download_images(coco_json_subset: dict[str, Any], root: Path):
    """Download a subset of COCO images.

    Parameters
    ----------
    coco_json_subset : dict[str, Any]
        COCO val2017_first4 JSON file.
    root : Path
        Location of COCO data.
    """
    root.mkdir(parents=True, exist_ok=True)

    for image in coco_json_subset["images"]:
        url = image["coco_url"]
        filename = Path(root) / image["file_name"]

        if filename.exists():
            print(f"skipping {url}")
        else:
            print(f"saving {url} to {filename} ... ", end="")
            r = requests.get(url)
            with open(filename, "wb") as f:
                f.write(r.content)
            print("done")
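The cell that loads the annotation subset and invokes this helper is not shown above; a minimal sketch follows. The image directory matches the output below, but the JSON filename (coco_val2017_first4.json) is an assumption based on the function's docstring.

# Assumed paths: image directory taken from the output below; JSON filename is illustrative
coco_root = Path("../sample_data/coco/coco_val2017_subset")
coco_json_file = Path("../sample_data/coco/coco_val2017_first4.json")  # hypothetical filename

with open(coco_json_file) as f:
    coco_json_subset = json.load(f)

download_images(coco_json_subset, coco_root)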
saving http://images.cocodataset.org/val2017/000000397133.jpg to ../sample_data/coco/coco_val2017_subset/000000397133.jpg ...
done
saving http://images.cocodataset.org/val2017/000000037777.jpg to ../sample_data/coco/coco_val2017_subset/000000037777.jpg ...
done
saving http://images.cocodataset.org/val2017/000000252219.jpg to ../sample_data/coco/coco_val2017_subset/000000252219.jpg ...
done
saving http://images.cocodataset.org/val2017/000000087038.jpg to ../sample_data/coco/coco_val2017_subset/000000087038.jpg ...
done
Next we use Torchvision's CocoDetection dataset class to load our COCO subset located under maite/examples/sample_data/coco/coco_val2017_subset. The keyword arguments of CocoDetection are:
root: the location of the folder containing the images
annFile: the location of the annotation JSON file
Please see Torchvision's documentation for more information on how to use the CocoDetection class.
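The cell that constructs tv_dataset is not shown; a minimal sketch, reusing the (assumed) paths from the download step above:

# Construct the Torchvision dataset (annotation file path is an assumption)
tv_dataset = CocoDetection(
    root=str(coco_root),
    annFile=str(coco_json_file),
)

print(f"{len(tv_dataset) = }")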
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
len(tv_dataset) = 4
2 Examine the source dataset#
One datum from tv_dataset
includes an image and its corresponding
annotations. Let’s get one datum and inspect it further.
img, annotations = tv_dataset[0]
# `img` is a PIL Image
plt.imshow(img)
plt.axis("off") # Optional: Hide the axes
plt.show()
print(f"{type(img) = }")
# `annotations` is a list of dictionaries (corresponding to 14 objects in this case)
print(f"{len(annotations) = }")
# Each annotation dictionary contains the object's bounding box (bbox) plus other info
print(f"{annotations[0].keys() = }")
# Note that the COCO bounding box format is [x_min, y_min, width, height]
print(f"{annotations[0]['bbox'] = }")
# Class/category labels and ids are available via the `cats` attribute on the dataset itself
print(f"{tv_dataset.coco.cats[1] = }")

type(img) = <class 'PIL.Image.Image'>
len(annotations) = 14
annotations[0].keys() = dict_keys(['segmentation', 'area', 'iscrowd', 'image_id', 'bbox', 'category_id', 'id'])
annotations[0]['bbox'] = [102.49, 118.47, 7.9, 17.31]
tv_dataset.coco.cats[1] = {'supercategory': 'person', 'id': 1, 'name': 'person'}
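For reference, the xywh-to-xyxy conversion that the wrapper in the next section applies to each bounding box can be sketched with Torchvision's box_convert, using the annotation shown above:

# Convert the first annotation's bbox from COCO [x_min, y_min, width, height] to corner format
bbox_xywh = torch.as_tensor(annotations[0]["bbox"])  # [102.49, 118.47, 7.9, 17.31]
bbox_xyxy = box_convert(bbox_xywh, in_fmt="xywh", out_fmt="xyxy")
print(bbox_xyxy)  # tensor([102.49, 118.47, 110.39, 135.78])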
3 Create a MAITE wrapper for the source dataset#
So far we’ve used Torchvision’s CocoDetection
class to load and
explore the dataset. It doesn’t comply with
maite.protocols.object_detection.Dataset
; however, it can be used as
a starting point to construct a MAITE-compliant object detection
dataset.
A class implementing maite.protocols.object_detection.Dataset needs the following two methods:
__getitem__(index: int): returns a tuple containing a datum's input image, ground truth detections, and metadata.
__len__(): returns the number of data elements in the dataset.
and an attribute:
metadata: a typed dictionary containing an id string and an optional (but recommended) map from class indexes to labels.
The datum returned by __getitem__ is of type tuple[InputType, ObjectDetectionTarget, DatumMetadata].
MAITE's InputType is an ArrayLike (e.g., numpy.ndarray or torch.Tensor) with shape (C, H, W) representing the channel, height, and width of an image. Since tv_dataset has images in the PIL image format, we need to convert them to a type compatible with ArrayLike, which we'll do using the pil_to_tensor function.
MAITE's ObjectDetectionTarget protocol can be implemented by defining a dataclass, CocoDetectionTarget, which encodes the ground-truth labels of the object detection task, namely:
boxes: a shape-(N_DETECTIONS, 4) ArrayLike, where N_DETECTIONS is the number of detections (objects) in the current single image, and each row is a bounding box in x0, y0, x1, y1 format
labels: a shape-(N_DETECTIONS,) ArrayLike containing the integer label associated with each detection
scores: a shape-(N_DETECTIONS,) ArrayLike containing the score (confidence) associated with each detection. For a dataset's ground truth (as opposed to model predictions), scores are always 1.
MAITE's DatumMetadata is a TypedDict containing metadata associated with a single datum. An id of type int or str is required.
Putting everything together, below is an implementation of the maite.protocols.object_detection.Dataset protocol called CocoDetectionDataset.
@dataclass
class CocoDetectionTarget:
    boxes: torch.Tensor
    labels: torch.Tensor
    scores: torch.Tensor


class CocoDetectionDataset:
    def __init__(self, dataset: CocoDetection, id: str):
        self._dataset = dataset

        # Get mapping from COCO category to name
        index2label = {k: v["name"] for k, v in dataset.coco.cats.items()}

        # Add dataset-level metadata attribute, with required id and optional index2label mapping
        self.metadata: DatasetMetadata = {
            "id": id,
            "index2label": index2label,
        }

    def __len__(self) -> int:
        return len(self._dataset)

    def __getitem__(
        self, index: int
    ) -> tuple[torch.Tensor, CocoDetectionTarget, DatumMetadata]:
        # Get original data item
        img_pil, annotations = self._dataset[index]

        # Format input
        img_pt = pil_to_tensor(img_pil)

        # Format ground truth
        num_boxes = len(annotations)
        boxes = torch.zeros(num_boxes, 4)
        for i, ann in enumerate(annotations):
            bbox = torch.as_tensor(ann["bbox"])
            boxes[i, :] = box_convert(bbox, in_fmt="xywh", out_fmt="xyxy")

        labels = torch.as_tensor([ann["category_id"] for ann in annotations])
        scores = torch.ones(num_boxes)

        # Format metadata
        datum_metadata: DatumMetadata = {
            "id": self._dataset.ids[index],
        }

        return img_pt, CocoDetectionTarget(boxes, labels, scores), datum_metadata
coco_dataset: od.Dataset = CocoDetectionDataset(tv_dataset, "COCO Detection Subset")
We have now created a dataset class that conforms to the
maite.protocols.object_detection.Dataset
protocol, and wrapped the
native dataset with it.
Here the coco_dataset
variable has od.Dataset
as the type hint.
If your development environment has a static type checker like Pyright
enabled (see the Enable Static Type Checking
guide), then the type checker will verify that our wrapped dataset
conforms to the protocol and indicate a problem if not.
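As a hypothetical illustration (the IncompleteDataset class and bad_dataset variable below are not part of this guide's code), a static type checker would flag the following assignment because the class is missing the required __getitem__ method and metadata attribute:

# A deliberately incomplete wrapper; a type checker such as Pyright reports an
# incompatibility on the annotated assignment below (nothing fails at runtime)
class IncompleteDataset:
    def __len__(self) -> int:
        return 0

bad_dataset: od.Dataset = IncompleteDataset()  # flagged by the static type checker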
4 Examine the MAITE-wrapped dataset#
Now let's inspect the MAITE-compliant object detection dataset to verify that it's behaving as expected.
First we verify that the length is 4:
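The cell performing this check is not shown above; it is simply a call to len, e.g.:

print(f"{len(coco_dataset) = }")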
len(coco_dataset) = 4
We display the dataset-level metadata:
pprint.pp(coco_dataset.metadata)
{'id': 'COCO Detection Subset',
'index2label': {1: 'person',
2: 'bicycle',
3: 'car',
4: 'motorcycle',
5: 'airplane',
6: 'bus',
7: 'train',
8: 'truck',
9: 'boat',
10: 'traffic light',
11: 'fire hydrant',
13: 'stop sign',
14: 'parking meter',
15: 'bench',
16: 'bird',
17: 'cat',
18: 'dog',
19: 'horse',
20: 'sheep',
21: 'cow',
22: 'elephant',
23: 'bear',
24: 'zebra',
25: 'giraffe',
27: 'backpack',
28: 'umbrella',
31: 'handbag',
32: 'tie',
33: 'suitcase',
34: 'frisbee',
35: 'skis',
36: 'snowboard',
37: 'sports ball',
38: 'kite',
39: 'baseball bat',
40: 'baseball glove',
41: 'skateboard',
42: 'surfboard',
43: 'tennis racket',
44: 'bottle',
46: 'wine glass',
47: 'cup',
48: 'fork',
49: 'knife',
50: 'spoon',
51: 'bowl',
52: 'banana',
53: 'apple',
54: 'sandwich',
55: 'orange',
56: 'broccoli',
57: 'carrot',
58: 'hot dog',
59: 'pizza',
60: 'donut',
61: 'cake',
62: 'chair',
63: 'couch',
64: 'potted plant',
65: 'bed',
67: 'dining table',
70: 'toilet',
72: 'tv',
73: 'laptop',
74: 'mouse',
75: 'remote',
76: 'keyboard',
77: 'cell phone',
78: 'microwave',
79: 'oven',
80: 'toaster',
81: 'sink',
82: 'refrigerator',
84: 'book',
85: 'clock',
86: 'vase',
87: 'scissors',
88: 'teddy bear',
89: 'hair drier',
90: 'toothbrush'}}
Finally we inspect a single datum:
# Get first datum
image, target, datum_metadata = coco_dataset[0]
# Display image (permute the (C, H, W) tensor to (H, W, C) for matplotlib)
plt.imshow(torch.as_tensor(image).permute(1, 2, 0))
plt.axis("off") # Optional: Hide the axes
plt.show()
# Bridge/convert ArrayLike's to PyTorch tensors
image = torch.as_tensor(image)
boxes = torch.as_tensor(target.boxes)
labels = torch.as_tensor(target.labels)
scores = torch.as_tensor(target.scores)
# Print shapes
print(f"{image.shape = }") # image has height 230 and width 352
print(f"{boxes.shape = }") # there are 14 bounding boxes
print(f"{labels.shape = }")
print(f"{scores.shape = }")
# Print datum-level metadata
print(f"{datum_metadata = }") # this datum corresponds to image file 000000037777.jpg

image.shape = torch.Size([3, 230, 352])
boxes.shape = torch.Size([14, 4])
labels.shape = torch.Size([14])
scores.shape = torch.Size([14])
datum_metadata = {'id': 37777}
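To make the integer labels easier to interpret, we can combine this datum's labels with the dataset-level index2label mapping. A small sketch reusing the variables above:

# Look up human-readable class names for this datum's labels
index2label = coco_dataset.metadata["index2label"]
print(sorted({index2label[int(label)] for label in labels}))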
5 Conclusion#
In this how-to guide, we demonstrated how to wrap an object detection dataset to be MAITE compliant. Key implementation steps included:
Defining __len__ and __getitem__ methods
Choosing a specific image InputType (e.g., numpy.ndarray or torch.Tensor)
Creating a target dataclass that conforms to the ObjectDetectionTarget protocol
Setting up DatasetMetadata and DatumMetadata
More dataset protocol details can be found here: maite.protocols.object_detection.Dataset.