maite.protocols.object_detection.Dataset

class maite.protocols.object_detection.Dataset(*args, **kwargs)

A dataset protocol for the object detection ML subproblem providing datum-level data access.

Implementers must provide index lookup (via the __getitem__(ind: int) method) and support len (via the __len__() method). Data elements looked up this way correspond to individual examples (as opposed to batches).

Indexing into or iterating over an object detection dataset returns a tuple of types ArrayLike, ObjectDetectionTarget, and DatumMetadataType. These correspond to the model input type, model target type, and datum-level metadata, respectively.
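
For example (a minimal sketch, with ds standing in for any compliant dataset; a full working implementation follows in the Examples section):

>>> # model_input, target, datum_metadata = ds[0]
>>> # num_datapoints = len(ds)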

Examples

We create a dummy set of data and use it to create a class that implements this lightweight dataset protocol:

>>> import numpy as np
>>> from dataclasses import dataclass
>>> from maite.protocols import (
...     DatumMetadata,
...     DatasetMetadata,
...     object_detection as od,
... )

Specify the parameters that will be used to create the dummy dataset.

>>> N_DATUM = 5  # data points in dataset
>>> N_CLASSES = 2  # possible classes that can be detected
>>> C = 3  # number of color channels
>>> H = 10  # image height
>>> W = 10  # image width

The dummy dataset will consist of a set of images, their associated annotations (detections), and some image-specific metadata (hour of day each image was taken).

We define helper functions to create the dummy annotations. Each image datum will be randomly assigned zero, one, or two detections. Annotations for each image consist of randomly generated bounding boxes and class labels.

>>> def generate_random_bbox(
...     n_classes: int, min_size: int = 2, max_size: int = 4
... ) -> np.ndarray:
...     # Generate random coordinates for top-left corner of bbox
...     x1 = np.random.randint(0, W - min_size)
...     y1 = np.random.randint(0, H - min_size)
...     # Generate random width and height, ensuring bounding box stays within image boundaries
...     bbox_width = np.random.randint(min_size, min(max_size, W - x1))
...     bbox_height = np.random.randint(min_size, min(max_size, H - y1))
...     # Set coordinates for bottom-right corner of bbox
...     x2 = x1 + bbox_width
...     y2 = y1 + bbox_height
...     # Pick random class label
...     label = np.random.choice(n_classes)
...     return np.array([x1, y1, x2, y2, label])
>>> def generate_random_annotation(max_num_detections: int = 2) -> np.ndarray:
...     num_detections = np.random.choice(max_num_detections + 1)
...     annotation = [generate_random_bbox(N_CLASSES) for _ in range(num_detections)]
...     # Stack detections into an (n, 5) array; use an empty (0, 5) array when there are none
...     return np.vstack(annotation) if num_detections > 0 else np.empty((0, 5))
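
As a quick sanity check (the assertion below holds regardless of the random draws), each generated annotation contains at most two detections:

>>> ann = generate_random_annotation()
>>> 0 <= len(ann) <= 2
True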

We now create the dummy dataset of images, corresponding annotations, and metadata.

>>> images: list[np.ndarray] = list(np.random.rand(N_DATUM, C, H, W))
>>> annotations: list[np.ndarray] = [
...     generate_random_annotation() for _ in range(N_DATUM)
... ]
>>> hour_of_day: list[int] = [int(np.random.choice(24)) for _ in range(N_DATUM)]
>>> dataset: list[tuple] = list(zip(images, annotations, hour_of_day))
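
As a quick check on the dummy data (the shapes below are deterministic even though the values are random):

>>> len(dataset)
5
>>> images[0].shape
(3, 10, 10)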

To support our MAITE dataset, we create an object detection target class that defines the boxes, labels, and scores for each detection in an image.

>>> @dataclass
... class MyObjectDetectionTarget:
...     boxes: np.ndarray
...     labels: np.ndarray
...     scores: np.ndarray
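
For illustration, a target holding a single notional detection (the values here are arbitrary) can be constructed directly:

>>> example_target = MyObjectDetectionTarget(
...     boxes=np.array([[1.0, 2.0, 5.0, 6.0]]),
...     labels=np.array([0]),
...     scores=np.array([1.0]),
... )
>>> example_target.labels
array([0])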

Lastly, we extend maite.protocols.DatumMetadata to hold datum-specific metadata, adding the notional hour-of-day field (in addition to the required unique id).

>>> class MyDatumMetadata(DatumMetadata):
...     hour_of_day: int
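
Because DatumMetadata is a typed dictionary, an instance of MyDatumMetadata is just a plain dict carrying both fields (the values here are arbitrary):

>>> datum_md: MyDatumMetadata = {"id": "datum_0", "hour_of_day": 13}
>>> datum_md["hour_of_day"]
13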

Constructing a compliant dataset now just involves a simple wrapper that fetches individual data points, where a data point is a single image, ground truth detection(s), and metadata.

>>> class ImageDataset:
...     # Set up required dataset-level metadata
...     metadata: DatasetMetadata = {
...         "id": "Dummy Dataset",
...         "index2label": {i: f"class_name_{i}" for i in range(N_CLASSES)}
...     }
...     def __init__(self, dataset: list[tuple[np.ndarray, np.ndarray, int]]):
...         self.dataset = dataset
...     def __len__(self) -> int:
...         return len(self.dataset)
...     def __getitem__(
...         self, index: int
...     ) -> tuple[np.ndarray, od.ObjectDetectionTarget, od.DatumMetadataType]:
...         if index < 0 or index >= len(self):
...             raise IndexError(f"Index {index} is out of range for the dataset, which has length {len(self)}.")
...         image, annotations, hour_of_day = self.dataset[index]
...         # Structure ground truth target (each annotation row is [x1, y1, x2, y2, label])
...         boxes, labels = [], []
...         for ann in annotations:
...             boxes.append(ann[:-1])
...             labels.append(ann[-1])
...         od_target = MyObjectDetectionTarget(
...             boxes=np.array(boxes), labels=np.array(labels), scores=np.ones(len(boxes))
...         )
...         # Structure datum-level metadata
...         datum_metadata: MyDatumMetadata = {"id": str(index), "hour_of_day": hour_of_day}
...         return image, od_target, datum_metadata

We can instantiate this class and type hint it as an object_detection.Dataset. By using type hinting, we permit a static typechecker to verify protocol compliance.

>>> maite_od_dataset: od.Dataset = ImageDataset(dataset)

Note that when writing a Dataset implementer, return types may be narrower than the return types promised by the protocol (np.ndarray is a subtype of ArrayLike), but the argument types must be at least as general as the argument types promised by the protocol.
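
As a final sanity check (using the definitions above; the contents of each target vary with the random annotations), individual datums come back as (input, target, metadata) tuples:

>>> len(maite_od_dataset)
5
>>> image, target, datum_metadata = maite_od_dataset[0]
>>> image.shape
(3, 10, 10)
>>> datum_metadata["id"]
'0'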

Attributes:
metadata : DatasetMetadata

A typed dictionary containing at least an ‘id’ field of type str.

Methods

__getitem__(ind: int) -> Tuple[ArrayLike, ObjectDetectionTarget, DatumMetadataType]

Provide mapping-style access to dataset elements. Returned tuple elements correspond to model input type, model target type, and datum-specific metadata, respectively.

__len__() -> int

Return the number of data elements in the dataset.