maite.protocols.object_detection.Model

class maite.protocols.object_detection.Model(*args, **kwargs)

A model protocol for the object detection ML subproblem.

Implementers must provide a __call__ method that operates on a batch of model inputs (as a Sequence[ArrayLike]) and returns a batch of model targets (as a Sequence[ObjectDetectionTarget]).
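Compliance with a protocol like this one is structural: a class qualifies by providing the required members with compatible signatures, without inheriting from the protocol. The sketch below illustrates that idea with a simplified, self-contained protocol built on typing.Protocol. SimpleDetector, SimpleTarget, Detections and EmptyDetector are hypothetical names used only for illustration; they are not MAITE's actual definitions.

```python
from dataclasses import dataclass
from typing import Any, Protocol, Sequence, runtime_checkable

import numpy as np


@runtime_checkable
class SimpleTarget(Protocol):
    # Hypothetical stand-in for an object detection target protocol.
    boxes: Any
    labels: Any
    scores: Any


@runtime_checkable
class SimpleDetector(Protocol):
    # Hypothetical stand-in for a model protocol: a metadata attribute plus a
    # __call__ that maps a batch of inputs to a batch of targets.
    metadata: dict

    def __call__(self, batch: Sequence[Any]) -> Sequence[SimpleTarget]: ...


@dataclass
class Detections:
    boxes: np.ndarray
    labels: np.ndarray
    scores: np.ndarray


class EmptyDetector:
    # No inheritance from SimpleDetector: compliance is purely structural.
    metadata: dict = {"id": "EmptyDetector"}

    def __call__(self, batch: Sequence[Any]) -> Sequence[Detections]:
        # Return one (empty) set of detections per input image.
        return [
            Detections(
                boxes=np.zeros((0, 4)),
                labels=np.zeros((0,), dtype=int),
                scores=np.zeros((0,)),
            )
            for _ in batch
        ]


model = EmptyDetector()
outputs = model([np.zeros((3, 8, 8)), np.zeros((3, 8, 8))])
print(isinstance(model, SimpleDetector))  # True
print(len(outputs), isinstance(outputs[0], SimpleTarget))  # 2 True
```

Note that isinstance against a runtime_checkable protocol only checks that the members exist, not their signatures or types. Full signature checking is the job of a static type checker, which is what the od.Model type hint in the Examples section relies on.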

Examples

We create a simple MAITE-compliant object detection model. Note that it is a dummy model that does not perform real detection.

>>> from dataclasses import dataclass
>>> from typing import Sequence
>>> import numpy as np
>>> import maite.protocols.object_detection as od
>>> from maite.protocols import ModelMetadata

We define an object detection target dataclass to hold the output of the object detection model.

>>> @dataclass
... class MyObjectDetectionTarget:
...     boxes: np.ndarray
...     labels: np.ndarray
...     scores: np.ndarray
...
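Each target holds the detections for a single image. The sketch below assumes the common convention of boxes with shape (N_DETECTIONS, 4) in (x0, y0, x1, y1) corner order, and labels and scores each with shape (N_DETECTIONS,); consult the ObjectDetectionTarget documentation for the authoritative shapes. check_target is a hypothetical helper, not part of MAITE.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class MyObjectDetectionTarget:
    boxes: np.ndarray
    labels: np.ndarray
    scores: np.ndarray


def check_target(target: MyObjectDetectionTarget) -> int:
    """Hypothetical helper: validate shapes and return the number of detections.

    Assumes boxes are (N, 4) as (x0, y0, x1, y1) and labels/scores are (N,).
    """
    boxes = np.asarray(target.boxes)
    labels = np.asarray(target.labels)
    scores = np.asarray(target.scores)
    if boxes.ndim != 2 or boxes.shape[1] != 4:
        raise ValueError(f"boxes must have shape (N, 4), got {boxes.shape}")
    n = boxes.shape[0]
    if labels.shape != (n,) or scores.shape != (n,):
        raise ValueError(
            f"labels/scores must have shape ({n},), got {labels.shape} and {scores.shape}"
        )
    # Assumed corner order: the second corner must not lie left of or above the first.
    if np.any(boxes[:, 2] < boxes[:, 0]) or np.any(boxes[:, 3] < boxes[:, 1]):
        raise ValueError("expected x1 >= x0 and y1 >= y0 for every box")
    return n


target = MyObjectDetectionTarget(
    boxes=np.array([[1, 3, 5, 9], [2, 5, 8, 12]]),
    labels=np.array([0, 4]),
    scores=np.array([0.9, 0.3]),
)
print(check_target(target))  # 2
```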

Specify parameters that will be used to create a dummy dataset.

>>> N_DATAPOINTS = 2  # datapoints in dataset
>>> N_CLASSES = 5  # possible classes that can be detected
>>> C = 3  # number of color channels
>>> H = 32  # img height
>>> W = 32  # img width

Now create a batch of data to use as input to the MAITE object detection model.

>>> simple_batch: list[np.ndarray] = [
...     np.random.rand(C, H, W) for _ in range(N_DATAPOINTS)
... ]

We define a simple object detection model. Note that no actual object detection is performed here: the __call__ method just returns MyObjectDetectionTarget instances.

>>> class ObjectDetectionDummyModel:
...     metadata: ModelMetadata = {"id": "ObjectDetectionDummyModel"}
...
...     def __call__(
...         self, batch: od.InputBatchType
...     ) -> Sequence[MyObjectDetectionTarget]:
...         # For simplicity, no real detector is run: the pixel values in `batch`
...         # are ignored, and hard-coded boxes plus random labels/scores stand in
...         # for a model's output.
...         DETECTIONS_PER_IMG = 2  # number of bounding-box detections per image
...         all_boxes = np.array(
...             [[1, 3, 5, 9], [2, 5, 8, 12], [4, 10, 8, 20], [3, 5, 6, 15]]
...         )  # all detection boxes for the N_DATAPOINTS images
...         all_predictions = list()
...         for datum_idx in range(N_DATAPOINTS):
...             # give each image its own disjoint block of DETECTIONS_PER_IMG boxes
...             start = datum_idx * DETECTIONS_PER_IMG
...             boxes = all_boxes[start : start + DETECTIONS_PER_IMG]
...             labels = np.random.randint(N_CLASSES, size=DETECTIONS_PER_IMG)
...             scores = np.random.rand(DETECTIONS_PER_IMG)
...             predictions = MyObjectDetectionTarget(boxes, labels, scores)
...             all_predictions.append(predictions)
...         return all_predictions
...

We can instantiate this class and type-hint it as a MAITE object detection model. The type hint lets a static type checker verify protocol compliance.

>>> od_dummy_model: od.Model = ObjectDetectionDummyModel()
>>> od_dummy_model.metadata
{'id': 'ObjectDetectionDummyModel'}
>>> predictions = od_dummy_model(simple_batch)
>>> predictions
[MyObjectDetectionTarget(boxes=array([[ 1,  3,  5,  9], [ 2,  5,  8, 12]]), labels=array([..., ...]), scores=array([..., ...])),
MyObjectDetectionTarget(boxes=array([[ 4, 10,  8, 20], [ 3,  5,  6, 15]]), labels=array([..., ...]), scores=array([..., ...]))]
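The model returns one target per input image, so downstream code typically iterates over the returned sequence. As an illustration only (this is not a MAITE API), the sketch below keeps the detections whose score meets a threshold; filter_by_score and the 0.5 threshold are hypothetical choices. It uses fixed scores instead of the random ones above so the result is deterministic.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class MyObjectDetectionTarget:
    boxes: np.ndarray
    labels: np.ndarray
    scores: np.ndarray


def filter_by_score(
    targets: list[MyObjectDetectionTarget], threshold: float = 0.5
) -> list[MyObjectDetectionTarget]:
    """Hypothetical post-processing: drop detections scoring below `threshold`."""
    filtered = []
    for t in targets:
        keep = t.scores >= threshold  # boolean mask over this image's detections
        filtered.append(
            MyObjectDetectionTarget(t.boxes[keep], t.labels[keep], t.scores[keep])
        )
    return filtered


predictions = [
    MyObjectDetectionTarget(
        boxes=np.array([[1, 3, 5, 9], [2, 5, 8, 12]]),
        labels=np.array([1, 3]),
        scores=np.array([0.8, 0.2]),
    ),
    MyObjectDetectionTarget(
        boxes=np.array([[4, 10, 8, 20], [3, 5, 6, 15]]),
        labels=np.array([0, 2]),
        scores=np.array([0.4, 0.6]),
    ),
]
kept = filter_by_score(predictions, threshold=0.5)
print([len(t.scores) for t in kept])  # [1, 1]
print(kept[1].boxes.tolist())  # [[3, 5, 6, 15]]
```

Filtering per image preserves the one-target-per-input correspondence, so the i-th filtered target still belongs to the i-th input image.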

Attributes

metadata : ModelMetadata

A typed dictionary containing at least an ‘id’ field of type str.
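The documented requirement on metadata is a typed dictionary with at least an 'id' string. The sketch below mirrors that single requirement with a hypothetical TypedDict and a small runtime check; SketchModelMetadata and has_valid_id are illustrative names, and MAITE's own ModelMetadata may define further optional keys.

```python
from typing import TypedDict


class SketchModelMetadata(TypedDict):
    # Hypothetical mirror of the documented requirement: an 'id' of type str.
    id: str


def has_valid_id(metadata: dict) -> bool:
    """Runtime check for the one documented requirement: a str-valued 'id'."""
    return isinstance(metadata.get("id"), str)


meta: SketchModelMetadata = {"id": "ObjectDetectionDummyModel"}
print(has_valid_id(meta))  # True
print(has_valid_id({"name": "no-id"}))  # False
```

A TypedDict is an ordinary dict at runtime, so the annotation only helps a static type checker; the explicit check is what catches a missing or non-string 'id' while the program runs.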

Methods

__call__(input_batch: Sequence[ArrayLike]) -> Sequence[ObjectDetectionTarget]

Make a model prediction for each input in the input batch. Each element of the input batch is expected to have shape (C, H, W).
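Because each element of the input batch is expected in (C, H, W) order, images stored in (H, W, C) order, as many image-loading libraries return them, need their axes moved before being passed to the model. A minimal NumPy sketch, with to_chw_batch as a hypothetical helper name:

```python
import numpy as np


def to_chw_batch(images_hwc: list[np.ndarray]) -> list[np.ndarray]:
    """Hypothetical helper: convert (H, W, C) images to the (C, H, W) layout."""
    batch = []
    for img in images_hwc:
        if img.ndim != 3:
            raise ValueError(f"expected a 3-D (H, W, C) image, got shape {img.shape}")
        # Move the channel axis from last to first; np.transpose returns a view.
        batch.append(np.transpose(img, (2, 0, 1)))
    return batch


hwc_images = [np.random.rand(32, 24, 3), np.random.rand(32, 24, 3)]
chw_batch = to_chw_batch(hwc_images)
print([x.shape for x in chw_batch])  # [(3, 32, 24), (3, 32, 24)]
```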