maite.protocols.object_detection.Augmentation
class maite.protocols.object_detection.Augmentation(*args, **kwargs)
An augmentation protocol for the object detection subproblem.
An augmentation is expected to take a batch of data and return a modified version of that batch. Implementers must provide a single method that takes and returns a labeled data batch, where a labeled data batch is represented by a tuple of types Sequence[ArrayLike], Sequence[ObjectDetectionTarget], and Sequence[DatumMetadataType]. These correspond to the model input batch type, model target batch type, and datum-level metadata batch type, respectively.

Examples
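Because this is a structural protocol, an object conforms by having the right shape rather than by inheriting from anything. The sketch below mirrors that shape with a local `typing.Protocol`; `AugmentationLike`, `AugmentationMetadataLike`, and `Identity` are hypothetical names for illustration, not MAITE's own definitions. Note that `isinstance` against a `runtime_checkable` protocol only checks that the members exist, not that their signatures match.

```python
from typing import Any, Protocol, Sequence, Tuple, TypedDict, runtime_checkable


# Hypothetical stand-ins that mirror the shapes described above;
# these are NOT imported from maite.
class AugmentationMetadataLike(TypedDict):
    id: str


# (inputs, targets, per-datum metadata), one entry per datum in each sequence
Batch = Tuple[Sequence[Any], Sequence[Any], Sequence[dict]]


@runtime_checkable
class AugmentationLike(Protocol):
    metadata: AugmentationMetadataLike

    def __call__(self, batch: Batch) -> Batch: ...


class Identity:
    """Simplest conforming implementer: returns the batch contents unchanged."""

    def __init__(self) -> None:
        self.metadata: AugmentationMetadataLike = {"id": "identity"}

    def __call__(self, batch: Batch) -> Batch:
        xb, yb, mdb = batch
        # Return new lists so callers never share containers with the input batch
        return list(xb), list(yb), list(mdb)


aug = Identity()
print(isinstance(aug, AugmentationLike))  # True: has `metadata` and `__call__`
print(isinstance(lambda b: b, AugmentationLike))  # False: no `metadata` attribute
```

The `lambda` fails the runtime check only because it lacks the `metadata` attribute listed under Attributes.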
We can write an implementer of the augmentation class as either a function or a class. The only requirement is that the object provide a __call__ method that accepts types at least as general as those promised in the protocol signature and returns types at least as specific. We create a dummy set of data and use it to create a class that implements this lightweight protocol and augments the data:
>>> import numpy as np
>>> np.random.seed(1)
>>> import copy
>>> from dataclasses import dataclass
>>> from typing import Any
>>> from maite.protocols import AugmentationMetadata, object_detection as od
First, we specify parameters that will be used to create the dummy dataset.
>>> N_DATAPOINTS = 3  # datapoints in dataset
>>> N_CLASSES = 2  # possible classes that can be detected
>>> C = 3  # number of color channels
>>> H = 10  # img height
>>> W = 10  # img width
Next, we create the input data to be used by the Augmentation. In this example, we create the following batch of object detection data:
- xb is the input batch data. Our batch will include N_DATAPOINTS samples. Note that we initialize all of the data to zeros in this example so that the effect of the augmentation is easy to see.
- yb is the object detection target data, which in this example represents zero object detections for each input datum (by having empty bounding boxes, class labels, and scores).
- mdb is the associated metadata for each input datum.
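Entry i of xb, yb, and mdb all describe the same datum, so an augmentation presumably has to keep the three sequences the same length and in the same order. The sketch below is a hypothetical helper, `check_batch`, that is not part of MAITE; it raises `ValueError` on misaligned lengths or on an input whose shape differs from an assumed `(C, H, W)`.

```python
from typing import Any, Sequence, Tuple

import numpy as np


def check_batch(
    batch: Tuple[Sequence[Any], Sequence[Any], Sequence[dict]],
    expected_shape: Tuple[int, int, int],
) -> None:
    """Raise ValueError if the batch is misaligned or an input has the wrong shape.

    Hypothetical helper; `expected_shape` is the (C, H, W) shape each input should have.
    """
    xb, yb, mdb = batch
    if not (len(xb) == len(yb) == len(mdb)):
        raise ValueError(
            f"misaligned batch: {len(xb)} inputs, {len(yb)} targets, {len(mdb)} metadata"
        )
    for i, x in enumerate(xb):
        shape = np.asarray(x).shape
        if shape != tuple(expected_shape):
            raise ValueError(f"input {i} has shape {shape}, expected {tuple(expected_shape)}")


# Same dimensions as the walkthrough: 3 datums of shape (C, H, W) = (3, 10, 10)
xb = list(np.zeros((3, 3, 10, 10)))
yb = [None, None, None]  # placeholder targets; only their count matters here
mdb = [{"id": i} for i in range(3)]
check_batch((xb, yb, mdb), (3, 10, 10))  # passes silently
```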
>>> @dataclass
... class MyObjectDetectionTarget:
...     boxes: np.ndarray
...     labels: np.ndarray
...     scores: np.ndarray
>>> xb: od.InputBatchType = list(np.zeros((N_DATAPOINTS, C, H, W)))
>>> yb: od.TargetBatchType = list(
...     MyObjectDetectionTarget(boxes=np.empty(0), labels=np.empty(0), scores=np.empty(0))
...     for _ in range(N_DATAPOINTS)
... )
>>> mdb: od.DatumMetadataBatchType = list({"id": i} for i in range(N_DATAPOINTS))
>>> # Display the first datum in batch, first color channel, and only first 5 rows and cols
>>> np.array(xb[0])[0][:5, :5]
array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]])
Now we create the Augmentation, which will apply random noise (rounded to 3 decimal places) to the input data using the numpy.random.random and np.round functions.
>>> np_noise = lambda shape: np.round(np.random.random(shape), 3)
>>> class ImageAugmentation:
...     def __init__(self, aug_func: Any, metadata: AugmentationMetadata):
...         self.aug_func = aug_func
...         self.metadata = metadata
...     def __call__(
...         self,
...         batch: tuple[od.InputBatchType, od.TargetBatchType, od.DatumMetadataBatchType],
...     ) -> tuple[od.InputBatchType, od.TargetBatchType, od.DatumMetadataBatchType]:
...         xb, yb, mdb = batch
...         # Copy the inputs passed to __call__ to avoid mutating the caller's data
...         xb_aug = [copy.copy(x) for x in xb]
...         # Add random noise to the copied input batch data
...         # (Note that all batch data dimensions (shapes) are the same in this example)
...         shape = np.array(xb[0]).shape
...         xb_aug = [x + self.aug_func(shape) for x in xb_aug]
...         # Note that this example augmentation only affects inputs, not targets
...         return xb_aug, yb, mdb
We can type hint an instance of the above class as an Augmentation in the object detection domain:
>>> noise: od.Augmentation = ImageAugmentation(np_noise, metadata={"id": "np_rand_noise"})
Now we can apply the noise augmentation to our 3-tuple batch of data. Recall that our data was initialized to all zeros, so any non-zero values in the augmented data are a result of the augmentation.
>>> xb_aug, yb_aug, mdb_aug = noise((xb, yb, mdb))
>>> # Display the first datum in batch, first color channel, and only first 5 rows and cols
>>> np.array(xb_aug[0])[0][:5, :5]
array([[0.417, 0.72 , 0.   , 0.302, 0.147],
       [0.419, 0.685, 0.204, 0.878, 0.027],
       [0.801, 0.968, 0.313, 0.692, 0.876],
       [0.098, 0.421, 0.958, 0.533, 0.692],
       [0.989, 0.748, 0.28 , 0.789, 0.103]])
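The example above relies on the global `np.random.seed(1)` call, so its output depends on any other code that draws from NumPy's global random state. One alternative, sketched here as an assumption rather than MAITE guidance, is to give each augmentation instance its own `numpy.random.Generator` created with `np.random.default_rng(seed)`. The class name `SeededNoise` is hypothetical, and its `metadata` is a plain dict standing in for `AugmentationMetadata`.

```python
from typing import Any, Sequence, Tuple

import numpy as np

Batch = Tuple[Sequence[Any], Sequence[Any], Sequence[dict]]


class SeededNoise:
    """Hypothetical noise augmentation that owns its random generator."""

    def __init__(self, seed: int, decimals: int = 3) -> None:
        self.metadata = {"id": f"seeded_noise_{seed}"}
        self.decimals = decimals
        self._rng = np.random.default_rng(seed)

    def __call__(self, batch: Batch) -> Batch:
        xb, yb, mdb = batch
        # `x + noise` allocates a new array, so the caller's inputs are not mutated
        xb_aug = [
            np.asarray(x) + np.round(self._rng.random(np.shape(x)), self.decimals)
            for x in xb
        ]
        return xb_aug, yb, mdb


xb = list(np.zeros((3, 3, 10, 10)))
batch = (xb, [None] * 3, [{"id": i} for i in range(3)])
out_a, _, _ = SeededNoise(seed=0)(batch)
out_b, _, _ = SeededNoise(seed=0)(batch)
print(all(np.array_equal(a, b) for a, b in zip(out_a, out_b)))  # True: same seed, same noise
print(all(not x.any() for x in xb))  # True: the original inputs are still all zeros
```

Two instances built with the same seed produce identical noise regardless of what else has consumed the global random state.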
Attributes:
- metadata : AugmentationMetadata
  A typed dictionary containing at least an ‘id’ field of type str.
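The walkthrough implements the protocol with a class, but it also notes that a function can serve as an implementer. Because the protocol declares this `metadata` attribute, a plain function presumably needs one attached as well; the sketch below does that with a hypothetical `clip_inputs` function. Static type checkers may not track attributes assigned to a function object, which is one reason the class form tends to be easier to type-check.

```python
from typing import Any, Sequence, Tuple

import numpy as np

Batch = Tuple[Sequence[Any], Sequence[Any], Sequence[dict]]


def clip_inputs(batch: Batch) -> Batch:
    """Hypothetical function-style augmentation: clip pixel values to [0, 1]."""
    xb, yb, mdb = batch
    # np.clip returns a new array, leaving the caller's inputs unchanged
    return [np.clip(np.asarray(x), 0.0, 1.0) for x in xb], yb, mdb


# Attach the metadata attribute: a mapping with at least an "id" field of type str
clip_inputs.metadata = {"id": "clip_0_1"}  # type: ignore

xb = [np.array([[[-0.5, 0.25], [1.5, 0.75]]])]  # one datum, shape (1, 2, 2)
xb_aug, _, _ = clip_inputs((xb, [None], [{"id": 0}]))
print(xb_aug[0])
```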
Methods
__call__(datum: Tuple[Sequence[ArrayLike], Sequence[ObjectDetectionTarget], Sequence[DatumMetadataType]]) -> Tuple[Sequence[ArrayLike], Sequence[ObjectDetectionTarget], Sequence[DatumMetadataType]]
Return a modified version of the original data batch. A data batch is represented by a tuple of the model input batch (as Sequence[ArrayLike], with elements of shape (C, H, W)), the model target batch (as Sequence[ObjectDetectionTarget]), and the batch metadata (as Sequence[DatumMetadataType]).
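The noise example modifies only the model input batch. A geometric augmentation must also update the model target batch so that boxes stay aligned with the pixels. The sketch below flips each `(C, H, W)` input left-to-right and mirrors its boxes, assuming boxes are stored as an `(N, 4)` array of `(x0, y0, x1, y1)` pixel coordinates. That box convention, the `Target` dataclass, and the `HorizontalFlip` name are all assumptions for illustration; the actual box format expected by `ObjectDetectionTarget` should be checked in MAITE's documentation.

```python
from dataclasses import dataclass, replace
from typing import Any, Sequence, Tuple

import numpy as np


@dataclass
class Target:
    """Hypothetical target mirroring the walkthrough's MyObjectDetectionTarget."""
    boxes: np.ndarray   # assumed shape (N, 4) as (x0, y0, x1, y1) in pixels
    labels: np.ndarray  # shape (N,)
    scores: np.ndarray  # shape (N,)


Batch = Tuple[Sequence[Any], Sequence[Target], Sequence[dict]]


class HorizontalFlip:
    """Hypothetical augmentation that flips images and their boxes left-to-right."""

    def __init__(self) -> None:
        self.metadata = {"id": "hflip"}

    def __call__(self, batch: Batch) -> Batch:
        xb, yb, mdb = batch
        xb_aug, yb_aug = [], []
        for x, y in zip(xb, yb):
            x = np.asarray(x)
            width = x.shape[-1]  # inputs are (C, H, W), so W is the last axis
            xb_aug.append(x[..., ::-1].copy())  # reverse the column order
            # reshape(-1, 4) also turns an empty np.empty(0) into shape (0, 4)
            boxes = np.asarray(y.boxes, dtype=float).reshape(-1, 4)
            flipped = boxes.copy()
            # Mirror x-coordinates: the new left edge is W minus the old right edge
            flipped[:, 0] = width - boxes[:, 2]
            flipped[:, 2] = width - boxes[:, 0]
            yb_aug.append(replace(y, boxes=flipped))
        return xb_aug, yb_aug, mdb


# One 1x4x10 image with a bright region covering columns 1 and 2
img = np.zeros((1, 4, 10))
img[:, :, 1:3] = 1.0
target = Target(
    boxes=np.array([[1.0, 0.0, 3.0, 4.0]]), labels=np.array([0]), scores=np.array([0.9])
)
xb_aug, yb_aug, _ = HorizontalFlip()(([img], [target], [{"id": 0}]))
print(yb_aug[0].boxes)  # [[7. 0. 9. 4.]]
print(np.flatnonzero(xb_aug[0][0, 0]))  # [7 8]
```

After the flip, the bright pixels move to columns 7 and 8, and the mirrored box spans exactly that region; labels and scores pass through unchanged.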