maite.protocols.image_classification.Augmentation
- class maite.protocols.image_classification.Augmentation(*args, **kwargs)
An augmentation protocol for the image classification subproblem.
An augmentation is expected to take a batch of data and return a modified version of that batch. Implementers must provide a single method that takes and returns a labeled data batch, where a labeled data batch is represented by a tuple of types Sequence[ArrayLike] (with elements of shape (C, H, W)), Sequence[ArrayLike] (with elements of shape (Cl,)), and Sequence[DatumMetadataType]. These correspond to the model input batch type, model target batch type, and datum-level metadata batch type, respectively.
Examples
We can write an implementer of the augmentation protocol as either a function or a class. The only requirement is that the object provide a __call__ method that accepts types at least as general as those promised in the protocol signature and returns types at least as specific.
>>> import copy
>>> import numpy as np
>>> from typing import Any
>>> from collections.abc import Sequence
>>> from maite.protocols import ArrayLike, DatumMetadata, AugmentationMetadata
>>>
>>> class EnrichedDatumMetadata(DatumMetadata):
...     new_key: int  # add a field to those already in DatumMetadata
...
>>> class ImageAugmentation:
...     def __init__(self, aug_name: str):
...         self.metadata: AugmentationMetadata = {'id': aug_name}
...     def __call__(
...         self,
...         data_batch: tuple[Sequence[ArrayLike], Sequence[ArrayLike], Sequence[DatumMetadata]]
...     ) -> tuple[Sequence[np.ndarray], Sequence[np.ndarray], Sequence[EnrichedDatumMetadata]]:
...         inputs, targets, mds = data_batch
...         # We copy the data passed into __call__ to avoid mutating the original inputs.
...         # By converting with np.array, the static type-checker will let us treat
...         # generic ArrayLike as a narrower return type.
...         inputs_aug = [copy.copy(np.array(input)) for input in inputs]
...         targets_aug = [copy.copy(np.array(target)) for target in targets]
...         # Modify inputs_aug, targets_aug, or mds_aug as needed.
...         # In this example, we just add a new metadata field.
...         mds_aug = []
...         for i, md in enumerate(mds):
...             mds_aug.append(EnrichedDatumMetadata(**md, new_key=i))
...         return inputs_aug, targets_aug, mds_aug
We can type-hint an instance of the above class as an Augmentation in the image_classification domain:
>>> from maite.protocols import image_classification as ic
>>> im_aug: ic.Augmentation = ImageAugmentation(aug_name='an_augmentation')
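As a rough illustration of what calling such an augmentation looks like, the sketch below applies a horizontal flip to a small batch. It is deliberately self-contained and does not import maite: plain dicts stand in for DatumMetadata and AugmentationMetadata, and the class name HorizontalFlip and the "flipped" metadata key are hypothetical names chosen for this example only.

```python
import numpy as np


class HorizontalFlip:
    """Hypothetical augmentation: mirrors each (C, H, W) image along its width axis."""

    def __init__(self, aug_name: str):
        # Plain dict standing in for AugmentationMetadata (assumption: only 'id' is needed).
        self.metadata = {"id": aug_name}

    def __call__(self, data_batch):
        inputs, targets, mds = data_batch
        # np.array copies its argument by default, so the caller's arrays are untouched.
        inputs_aug = [np.flip(np.array(x), axis=-1) for x in inputs]
        targets_aug = [np.array(t) for t in targets]
        # Build new metadata dicts rather than mutating the originals.
        mds_aug = [{**md, "flipped": True} for md in mds]
        return inputs_aug, targets_aug, mds_aug


# A batch of two single-channel images (C=1, H=3, W=4) with three classes (Cl=3).
images = [np.arange(12, dtype=np.float32).reshape(1, 3, 4) for _ in range(2)]
targets = [np.array([0.0, 1.0, 0.0]), np.array([1.0, 0.0, 0.0])]
metadata = [{"id": 0}, {"id": 1}]

aug = HorizontalFlip("hflip")
out_images, out_targets, out_metadata = aug((images, targets, metadata))
```

The output batch keeps the same three-part structure and element shapes as the input, which is what lets augmentations be chained or placed between a dataset and a model.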
- Attributes:
- metadata : AugmentationMetadata
A typed dictionary containing at least an 'id' field of type str.
Methods
__call__(datum: tuple[Sequence[ArrayLike], Sequence[ArrayLike], Sequence[DatumMetadataType]]) -> tuple[Sequence[ArrayLike], Sequence[ArrayLike], Sequence[DatumMetadataType]]
Return a modified version of the original data batch. A data batch is represented by a tuple of model input batch (as Sequence[ArrayLike] with elements of shape (C, H, W)), model target batch (as Sequence[ArrayLike] with elements of shape (Cl,)), and batch metadata (as Sequence[DatumMetadataType]), respectively.
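The batch layout described above can be checked with a small helper, sketched here under stated assumptions. The function check_ic_batch is hypothetical and not part of maite; it also assumes that the three sequences are expected to have equal length, which the protocol text implies but does not state explicitly.

```python
import numpy as np


def check_ic_batch(batch) -> int:
    """Hypothetical helper: validate an image-classification batch and return its size."""
    inputs, targets, metadata = batch
    # Assumption: one target and one metadata entry per input.
    if not (len(inputs) == len(targets) == len(metadata)):
        raise ValueError("inputs, targets, and metadata must have equal length")
    for x, y in zip(inputs, targets):
        # Each input is expected to have shape (C, H, W).
        if np.asarray(x).ndim != 3:
            raise ValueError("each input must have shape (C, H, W)")
        # Each target is expected to have shape (Cl,).
        if np.asarray(y).ndim != 1:
            raise ValueError("each target must have shape (Cl,)")
    return len(inputs)


batch_size = check_ic_batch(
    ([np.zeros((3, 8, 8))], [np.zeros(10)], [{"id": 0}])
)
```

Running the same check on the input and output of an augmentation is one simple way to confirm that it preserved the expected batch structure.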