maite.protocols.object_detection.Augmentation

class maite.protocols.object_detection.Augmentation(*args, **kwargs)

An augmentation protocol for the object detection subproblem.

An augmentation is expected to take a batch of data and return a modified version of that batch. Implementers must provide a single method, __call__, that takes and returns a labeled data batch, represented as a 3-tuple whose elements have types Sequence[ArrayLike], Sequence[ObjectDetectionTarget], and Sequence[DatumMetadataType]. These correspond to the model input batch type, model target batch type, and datum-level metadata batch type, respectively.
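Because this is a structural protocol, any callable with a matching batch-in, batch-out signature qualifies; no inheritance is needed. The following is a minimal sketch, not MAITE's actual definition, of how a protocol with this shape can be expressed with the standard library's typing.Protocol. The names Target, Batch, AugmentationLike, and identity are hypothetical stand-ins for ObjectDetectionTarget, the labeled batch tuple, and Augmentation, and arrays are approximated by nested lists.

```python
from dataclasses import dataclass
from typing import Any, Protocol, Sequence, runtime_checkable


# Hypothetical stand-in for ObjectDetectionTarget: per-datum boxes, labels, and scores
@dataclass
class Target:
    boxes: list[list[float]]
    labels: list[int]
    scores: list[float]


# Hypothetical alias for a labeled batch: (inputs, targets, datum metadata)
Batch = tuple[Sequence[Any], Sequence[Target], Sequence[dict[str, Any]]]


@runtime_checkable
class AugmentationLike(Protocol):
    """Anything callable as batch -> batch satisfies this protocol structurally."""

    def __call__(self, batch: Batch, /) -> Batch: ...


def identity(batch: Batch) -> Batch:
    # The simplest possible implementer: return the batch unchanged
    return batch


aug: AugmentationLike = identity  # accepted by static type checkers, no inheritance needed
# A runtime isinstance check only confirms that __call__ exists, not its signature
print(isinstance(identity, AugmentationLike))  # True
```

The signature match is checked by static type checkers such as mypy or pyright; at runtime, runtime_checkable only tests for the presence of the method.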

Examples

We can write an implementer of the augmentation protocol as either a function or a class. The only requirement is that the object provide a __call__ method that accepts arguments at least as general as the types promised in the protocol signature and returns values at least as specific.
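As a sketch of the function form, assuming images are plain nested lists of floats in [0, 1] with (C, H, W) layout rather than NumPy arrays, the hypothetical brighten function below accepts any Sequence for each batch component (more general than list) and returns concrete lists (more specific). That is the direction of variance the protocol permits. Target is a hypothetical stand-in for ObjectDetectionTarget.

```python
from dataclasses import dataclass
from typing import Any, Sequence


# Hypothetical stand-in for an object detection target
@dataclass
class Target:
    boxes: list[list[float]]
    labels: list[int]
    scores: list[float]


def brighten(
    batch: tuple[Sequence[Any], Sequence[Target], Sequence[dict[str, Any]]],
) -> tuple[list[Any], list[Target], list[dict[str, Any]]]:
    """Scale every pixel by 1.5 and clip at 1.0; targets and metadata pass through."""
    xb, yb, mdb = batch
    # Build new nested lists, so the caller's inputs are never mutated
    xb_aug = [
        [[[min(1.0, 1.5 * px) for px in row] for row in channel] for channel in img]
        for img in xb
    ]
    # Brightness does not move objects, so boxes, labels, and scores are unchanged
    return xb_aug, list(yb), list(mdb)


img = [[[0.25, 0.8], [0.5, 0.0]]]  # one image with C=1, H=2, W=2
xb_aug, yb_aug, mdb_aug = brighten(((img,), (Target([], [], []),), ({"id": 0},)))
print(xb_aug[0])  # [[[0.375, 1.0], [0.75, 0.0]]]
```

The input batch here is a tuple of tuples, which is still a valid Sequence, while the output is built from lists.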

We create a dummy batch of data, then write a class that implements this lightweight protocol and use it to augment the data:

>>> import numpy as np
>>> np.random.seed(1)
>>> import copy
>>> from dataclasses import dataclass
>>> from typing import Any
>>> from maite.protocols import AugmentationMetadata, object_detection as od

First, we specify parameters that will be used to create the dummy dataset.

>>> N_DATAPOINTS = 3  # datapoints in dataset
>>> N_CLASSES = 2  # possible classes that can be detected
>>> C = 3  # number of color channels
>>> H = 10  # img height
>>> W = 10  # img width

Next, we create the input data to be used by the Augmentation. In this example, we create the following batch of object detection data:

  • xb is the input batch data. Our batch will include N_DATAPOINTS samples. Note that we initialize all of the data to zeros in this example to better demonstrate the effect of the augmentation.

  • yb is the object detection target data, which in this example represents zero object detections for each input datum (by having empty bounding boxes, class labels, and scores).

  • mdb is the associated metadata for each input datum.

>>> @dataclass
... class MyObjectDetectionTarget:
...     boxes: np.ndarray
...     labels: np.ndarray
...     scores: np.ndarray
>>> xb: od.InputBatchType = list(np.zeros((N_DATAPOINTS, C, H, W)))
>>> yb: od.TargetBatchType = list(
...     MyObjectDetectionTarget(boxes=np.empty(0), labels=np.empty(0), scores=np.empty(0))
...     for _ in range(N_DATAPOINTS)
... )
>>> mdb: od.DatumMetadataBatchType = list({"id": i} for i in range(N_DATAPOINTS))
>>> # Display the first datum in batch, first color channel, and only first 5 rows and cols
>>> np.array(xb[0])[0][:5, :5]
array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]])

Now we create the Augmentation, which adds uniform random noise from numpy.random.random, rounded to 3 decimal places with numpy.round, to the input data.

>>> np_noise = lambda shape: np.round(np.random.random(shape), 3)
>>> class ImageAugmentation:
...     def __init__(self, aug_func: Any, metadata: AugmentationMetadata):
...         self.aug_func = aug_func
...         self.metadata = metadata
...     def __call__(
...         self,
...         batch: tuple[od.InputBatchType, od.TargetBatchType, od.DatumMetadataBatchType],
...     ) -> tuple[od.InputBatchType, od.TargetBatchType, od.DatumMetadataBatchType]:
...         xb, yb, mdb = batch
...         # Add random noise to the input batch data, xb. Addition creates new arrays,
...         # so the caller's original inputs are not mutated.
...         # (Note that all batch data dimensions (shapes) are the same in this example)
...         shape = np.array(xb[0]).shape
...         xb_aug = [np.asarray(x) + self.aug_func(shape) for x in xb]
...         # Note that this example augmentation only affects inputs--not targets
...         return xb_aug, yb, mdb

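We can type hint an instance of the above class as an Augmentation in the object detection domain. Because Augmentation is a protocol, static type checkers accept the class without it inheriting from anything: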

>>> noise: od.Augmentation = ImageAugmentation(np_noise, metadata={"id": "np_rand_noise"})

Now we can apply the noise augmentation to our 3-tuple batch of data. Recall that our data was initialized to all zeros, so any non-zero values in the augmented data are the result of the augmentation.

>>> xb_aug, yb_aug, mdb_aug = noise((xb, yb, mdb))
>>> # Display the first datum in batch, first color channel, and only first 5 rows and cols
>>> np.array(xb_aug[0])[0][:5, :5]
array([[0.417, 0.72 , 0.   , 0.302, 0.147],
       [0.419, 0.685, 0.204, 0.878, 0.027],
       [0.801, 0.968, 0.313, 0.692, 0.876],
       [0.098, 0.421, 0.958, 0.533, 0.692],
       [0.989, 0.748, 0.28 , 0.789, 0.103]])
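The behavior described in this example (original inputs left untouched, targets and metadata passed through, and every augmented value coming purely from the noise) can be checked directly. Note that after rounding to 3 decimals the noise lies in [0, 1] inclusive, since values just below 1 round up to 1.0. The sketch below is a standard-library re-implementation under stated assumptions: NoiseAugmentation is a hypothetical analogue of ImageAugmentation, images are nested lists, targets are plain dicts, and a seeded random.Random instance replaces NumPy's global random state, so its numbers will not match the output above.

```python
import copy
import random
from typing import Any


class NoiseAugmentation:
    """Hypothetical stdlib analogue of ImageAugmentation: adds rounded uniform noise."""

    def __init__(self, seed: int, metadata: dict[str, Any]):
        self.rng = random.Random(seed)  # instance-local RNG, independent of global state
        self.metadata = metadata

    def __call__(self, batch):
        xb, yb, mdb = batch
        # Build new nested lists, so the caller's inputs are never mutated
        xb_aug = [
            [[[px + round(self.rng.random(), 3) for px in row] for row in ch] for ch in img]
            for img in xb
        ]
        # Only inputs change; targets and metadata are returned as the same objects
        return xb_aug, yb, mdb


N, C, H, W = 3, 3, 10, 10
xb = [[[[0.0] * W for _ in range(H)] for _ in range(C)] for _ in range(N)]
xb_before = copy.deepcopy(xb)
yb = [{"boxes": [], "labels": [], "scores": []} for _ in range(N)]  # no detections
mdb = [{"id": i} for i in range(N)]

noise = NoiseAugmentation(seed=1, metadata={"id": "rand_noise"})
xb_aug, yb_aug, mdb_aug = noise((xb, yb, mdb))

assert xb == xb_before  # the caller's inputs were not mutated
assert yb_aug is yb and mdb_aug is mdb  # targets and metadata passed through
# Inputs were all zeros, so every augmented value is pure noise in [0, 1]
assert all(0.0 <= px <= 1.0 for img in xb_aug for ch in img for row in ch for px in row)
```

Keeping the random generator on the instance makes the augmentation reproducible even if other code consumes the global random state; with NumPy, the analogous choice is a numpy.random.Generator created by numpy.random.default_rng.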

Attributes:

metadata : AugmentationMetadata

A typed dictionary containing at least an ‘id’ field of type str.
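A dictionary with a required id key plus optional extra fields can be described with typing.TypedDict. The sketch below uses hypothetical names (AugMetadataSketch, NoiseMetadata, noise_std) and is not MAITE's actual AugmentationMetadata definition; it only illustrates the "at least an id field" contract.

```python
from typing import TypedDict


class AugMetadataSketch(TypedDict):
    id: str  # required identifier for the augmentation


class NoiseMetadata(AugMetadataSketch, total=False):
    noise_std: float  # hypothetical optional extra field


meta: NoiseMetadata = {"id": "np_rand_noise"}
meta_with_extra: NoiseMetadata = {"id": "np_rand_noise", "noise_std": 0.1}
# TypedDicts are plain dicts at runtime; the key contract is enforced by type checkers
print(type(meta) is dict, meta["id"])  # True np_rand_noise
```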

Methods

__call__(datum: Tuple[Sequence[ArrayLike], Sequence[ObjectDetectionTarget], Sequence[DatumMetadataType]]) -> Tuple[Sequence[ArrayLike], Sequence[ObjectDetectionTarget], Sequence[DatumMetadataType]]

Return a modified version of the original data batch. A data batch is represented by a tuple of the model input batch (as Sequence[ArrayLike] with elements of shape (C, H, W)), the model target batch (as Sequence[ObjectDetectionTarget]), and the batch metadata (as Sequence[DatumMetadataType]), respectively.
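When an augmentation changes image geometry, the target batch must be transformed consistently with the inputs. The hypothetical hflip function below sketches this for a horizontal flip, assuming images are nested lists in (C, H, W) layout and boxes are [x0, y0, x1, y1] in pixel coordinates with 0 <= x0 < x1 <= W. Both conventions, and the Target stand-in, are assumptions for illustration rather than details taken from MAITE.

```python
from dataclasses import dataclass, replace
from typing import Any


@dataclass
class Target:
    boxes: list[list[float]]  # assumed [x0, y0, x1, y1] pixel coordinates
    labels: list[int]
    scores: list[float]


def hflip(batch: tuple[list[Any], list[Target], list[dict[str, Any]]]):
    xb, yb, mdb = batch
    # Reverse every row so that column j moves to column W - 1 - j
    xb_aug = [[[row[::-1] for row in ch] for ch in img] for img in xb]
    yb_aug = []
    for img, tgt in zip(xb, yb):
        width = len(img[0][0])  # W from the (C, H, W) layout
        # Mirror x-coordinates; x0 and x1 swap roles so that x0 < x1 still holds
        boxes = [[width - x1, y0, width - x0, y1] for x0, y0, x1, y1 in tgt.boxes]
        yb_aug.append(replace(tgt, boxes=boxes))  # new Target; original is untouched
    return xb_aug, yb_aug, mdb


img = [[[1.0, 2.0, 3.0, 4.0]]]  # C=1, H=1, W=4
tgt = Target(boxes=[[0.0, 0.0, 1.0, 1.0]], labels=[1], scores=[0.9])
xb_aug, yb_aug, _ = hflip(([img], [tgt], [{"id": 0}]))
print(xb_aug[0], yb_aug[0].boxes)  # [[[4.0, 3.0, 2.0, 1.0]]] [[3.0, 0.0, 4.0, 1.0]]
```

Applying hflip twice returns the original images and boxes, which is a convenient sanity check for this kind of geometric augmentation.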