maite.protocols.image_classification.Augmentation

class maite.protocols.image_classification.Augmentation(*args, **kwargs)

An augmentation protocol for the image classification subproblem.

An augmentation is expected to take a batch of data and return a modified version of that batch. Implementers must provide a single method that takes and returns a labeled data batch, where a labeled data batch is represented by a tuple of types Sequence[ArrayLike] (with elements of shape (C, H, W)), Sequence[ArrayLike] (with elements of shape (Cl, )), and Sequence[DatumMetadataType]. These correspond to the model input batch type, model target batch type, and datum-level metadata batch type, respectively.
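As a rough illustration of the batch layout described above, the following sketch builds a labeled data batch by hand. The concrete sizes (two images, 3 channels, 10 classes), the one-hot target encoding, and the use of an `id` key in each metadata dictionary are assumptions made for this example, not requirements stated here.

```python
import numpy as np

# Hypothetical batch of N=2 images, each of shape (C, H, W) = (3, 4, 5)
inputs = [np.zeros((3, 4, 5), dtype=np.float32) for _ in range(2)]

# One target vector per image, each of shape (Cl,), here one-hot over 10 classes
num_classes = 10
targets = [np.eye(num_classes, dtype=np.float32)[label] for label in (3, 7)]

# One metadata dictionary per datum (the 'id' key is assumed for illustration)
metadata = [{"id": 0}, {"id": 1}]

# The labeled data batch is the 3-tuple (input batch, target batch, metadata batch)
batch = (inputs, targets, metadata)
```

All three sequences have the same length, one entry per datum in the batch.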

Examples

We can write an implementer of the Augmentation protocol as either a function or a class. The only requirement is that the object provides a __call__ method that takes objects at least as general as the types promised in the protocol signature and returns types at least as specific.

>>> import copy
>>> import numpy as np
>>> from typing import Any
>>> from collections.abc import Sequence
>>> from maite.protocols import ArrayLike, DatumMetadata, AugmentationMetadata
>>>
>>> class EnrichedDatumMetadata(DatumMetadata):
...     new_key: int  # add a field to those already in DatumMetadata
...
>>> class ImageAugmentation:
...     def __init__(self, aug_name: str):
...         self.metadata: AugmentationMetadata = {'id': aug_name}
...     def __call__(
...         self,
...         data_batch: tuple[Sequence[ArrayLike], Sequence[ArrayLike], Sequence[DatumMetadata]]
...     ) -> tuple[Sequence[np.ndarray], Sequence[np.ndarray], Sequence[EnrichedDatumMetadata]]:
...         inputs, targets, mds = data_batch
...         # We copy the data passed into __call__ to avoid mutating the original inputs.
...         # Converting with np.array lets the static type checker treat the generic
...         # ArrayLike as the narrower np.ndarray return type.
...         inputs_aug = [copy.copy(np.array(input)) for input in inputs]
...         targets_aug = [copy.copy(np.array(target)) for target in targets]
...         # Modify inputs_aug, targets_aug, or mds_aug as needed
...         # In this example, we just add a new metadata field
...         mds_aug = []
...         for i, md in enumerate(mds):
...             mds_aug.append(EnrichedDatumMetadata(**md, new_key=i))
...         return inputs_aug, targets_aug, mds_aug

We can typehint an instance of the above class as an Augmentation in the image_classification domain:

>>> from maite.protocols import image_classification as ic
>>> im_aug: ic.Augmentation = ImageAugmentation(aug_name = 'an_augmentation')
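The example above only enriches the metadata. As a complementary sketch, the class below actually modifies the inputs by mirroring each image along its width axis and then applies itself to a one-element batch. The `HorizontalFlip` name, the `flipped` metadata key, and the `id` key in the datum metadata are hypothetical and chosen for illustration. The class is written without importing maite, on the assumption that the protocol is matched structurally (a `metadata` attribute plus a compatible `__call__`).

```python
import numpy as np


class HorizontalFlip:
    """Hypothetical augmentation that mirrors each image along its width axis."""

    def __init__(self, aug_name: str = "horizontal_flip"):
        # The protocol expects a `metadata` mapping with at least an 'id' string
        self.metadata = {"id": aug_name}

    def __call__(self, data_batch):
        inputs, targets, mds = data_batch
        # np.array(...) copies by default, so the caller's arrays are not mutated;
        # axis -1 is W for images of shape (C, H, W)
        inputs_aug = [np.flip(np.array(x), axis=-1) for x in inputs]
        targets_aug = [np.array(t) for t in targets]
        # Build new metadata dicts rather than editing the originals
        # ('flipped' is an illustrative key, not a MAITE field)
        mds_aug = [{**md, "flipped": True} for md in mds]
        return inputs_aug, targets_aug, mds_aug


image = np.arange(12, dtype=np.float32).reshape(1, 3, 4)  # (C, H, W) = (1, 3, 4)
target = np.array([0.0, 1.0])                             # (Cl,) = (2,)
batch = ([image], [target], [{"id": 0}])

aug = HorizontalFlip()
inputs_aug, targets_aug, mds_aug = aug(batch)
```

The returned batch has the same three-part structure as the input, and the original `image`, `target`, and metadata dictionary are left unchanged.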

Attributes:
metadata : AugmentationMetadata

A typed dictionary containing at least an 'id' field of type str.

Methods

__call__(datum: tuple[Sequence[ArrayLike], Sequence[ArrayLike], Sequence[DatumMetadataType]]) -> tuple[Sequence[ArrayLike], Sequence[ArrayLike], Sequence[DatumMetadataType]]

Return a modified version of the original data batch. A data batch is represented by a tuple of model input batch (as Sequence[ArrayLike] with elements of shape (C, H, W)), model target batch (as Sequence[ArrayLike] with elements of shape (Cl,)), and batch metadata (as Sequence[DatumMetadataType]), respectively.