..
   Copyright 2024, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
   Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
   SPDX-License-Identifier: MIT
Overview of MAITE Protocols
===========================

MAITE provides protocols for the following AI components:

- models
- datasets
- dataloaders
- augmentations
- metrics

MAITE protocols specify the expected interfaces of these components
(i.e., a minimal set of required attributes, methods, and method type
signatures) to promote interoperability in test and evaluation (T&E).
This enables the creation of higher-level workflows (e.g., an
``evaluate`` utility) that can interact with any components that
conform to the protocols.

1 Concept: Bridging ArrayLikes
------------------------------

MAITE uses a type called ``ArrayLike`` (following NumPy's
interoperability approach) that helps components that natively use
different flavors of tensors (e.g., NumPy ndarray, PyTorch Tensor, JAX
ndarray) work together.

In this example, the functions “type narrow” from ``ArrayLike`` to the
type they want to work with internally. Note that this doesn’t
necessarily require a conversion, depending on the actual input type.

.. code:: ipython3

    import numpy as np
    import torch

    from maite.protocols import ArrayLike

    def my_numpy_fn(x: ArrayLike) -> np.ndarray:
        arr = np.asarray(x)
        # ...
        return arr

    def my_torch_fn(x: ArrayLike) -> torch.Tensor:
        tensor = torch.as_tensor(x)
        # ...
        return tensor

    # can apply NumPy function to PyTorch Tensor
    np_out = my_numpy_fn(torch.rand(2, 3))

    # can apply PyTorch function to NumPy array
    torch_out = my_torch_fn(np.random.rand(2, 3))

    # note: no performance hit from conversion when all `ArrayLike`s are from the same library
    # or when they can share the same underlying memory
    torch_out = my_torch_fn(torch.rand(2, 3))

By using bridging, MAITE permits implementers of the protocols to
internally interact with their own types while exposing a more open
interface to other MAITE-compliant components.

2 Data Types
------------

MAITE represents an *individual* data item as a tuple of:

- input (i.e., image),
- target (i.e., label), and
- metadata (at the datum level)

and a *batch* of data items as a tuple of:

- input batches,
- target batches, and
- metadata batches.

MAITE provides versions of ``Model``, ``Dataset``, ``DataLoader``,
``Augmentation``, and ``Metric`` protocols that correspond to different
machine learning tasks (e.g., image classification, object detection)
by parameterizing protocol interfaces on the particular input, target,
and metadata types associated with that task.

2.1 Image Classification
~~~~~~~~~~~~~~~~~~~~~~~~

For image classification with ``Cl`` image classes, we have:

.. code:: python

    # define type to store an id for each datum (additional fields can be
    # added by defining a structurally-assignable TypedDict)
    class DatumMetadataType(TypedDict):
        id: str | int

    InputType: TypeAlias = ArrayLike  # shape-(C, H, W) tensor with single image
    TargetType: TypeAlias = ArrayLike  # shape-(Cl,) tensor of one-hot encoded true class or predicted probabilities

    InputBatchType: TypeAlias = Sequence[ArrayLike]  # sequence of N shape-(C, H, W) tensors
    TargetBatchType: TypeAlias = Sequence[ArrayLike]  # sequence of N shape-(Cl,) tensors
    DatumMetadataBatchType: TypeAlias = Sequence[DatumMetadataType]

Notes:

- ``TargetType`` is used for both ground truth (coming from a dataset)
  and predictions (output from a model); a concrete datum is sketched
  after this list. So for a problem with 4 classes,

  - a true label of class 2 would be one-hot encoded as ``[0, 0, 1, 0]``
  - a prediction from a model would be a vector of pseudo-probabilities,
    e.g., ``[0.1, 0.0, 0.7, 0.2]``

- ``InputType`` and ``InputBatchType`` are shown with shapes following
  the PyTorch channels-first convention
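To make these aliases concrete, here is a minimal sketch (not part of
MAITE itself) of a single datum for a hypothetical 4-class problem,
using NumPy arrays, which satisfy ``ArrayLike``:

.. code:: python

    import numpy as np

    image = np.zeros((3, 32, 32))            # InputType: shape-(C, H, W)
    target = np.array([0.0, 0.0, 1.0, 0.0])  # TargetType: one-hot true class 2
    metadata = {"id": 0}                     # DatumMetadataType

    datum = (image, target, metadata)        # an individual data item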
These type aliases, along with the versions of the various component
protocols that use these types, can be imported from
``maite.protocols.image_classification``:

.. code:: ipython3

    # import protocol classes
    from maite.protocols.image_classification import (
        Dataset,
        DataLoader,
        Model,
        Augmentation,
        Metric
    )

    # import type aliases
    from maite.protocols.image_classification import (
        InputType,
        TargetType,
        DatumMetadataType,
        InputBatchType,
        TargetBatchType,
        DatumMetadataBatchType
    )

Alternatively, image classification components and types can be
accessed via the module directly:

.. code:: ipython3

    import maite.protocols.image_classification as ic

    # model: ic.Model = load_model(...)

2.2 Object Detection
~~~~~~~~~~~~~~~~~~~~

For object detection with ``D_i`` detections in an image ``i``, we
have:

.. code:: python

    # define type to store an id for each datum (additional fields can be
    # added by defining a structurally-assignable TypedDict)
    class DatumMetadataType(TypedDict):
        id: str | int

    class ObjectDetectionTarget(Protocol):
        @property
        def boxes(self) -> ArrayLike: ...  # shape-(D_i, 4) tensor of bounding boxes w/format X0, Y0, X1, Y1

        @property
        def labels(self) -> ArrayLike: ...  # shape-(D_i,) tensor of labels for each box

        @property
        def scores(self) -> ArrayLike: ...  # shape-(D_i,) tensor of scores for each box (e.g., probabilities)

    InputType: TypeAlias = ArrayLike  # shape-(C, H, W) tensor with single image
    TargetType: TypeAlias = ObjectDetectionTarget

    InputBatchType: TypeAlias = Sequence[ArrayLike]  # sequence of N ArrayLikes each of shape (C, H, W)
    TargetBatchType: TypeAlias = Sequence[TargetType]  # sequence of object detection "target" objects
    DatumMetadataBatchType: TypeAlias = Sequence[DatumMetadataType]

Notes:

- ``ObjectDetectionTarget`` contains a single label and score per box
  (see the sketch below)
- ``InputType`` and ``InputBatchType`` are shown with shapes following
  the PyTorch channels-first convention
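Because MAITE protocols are structural, any class that exposes
``boxes``, ``labels``, and ``scores`` attributes satisfies
``ObjectDetectionTarget``; no inheritance is required. A minimal sketch
(with illustrative values only):

.. code:: python

    from dataclasses import dataclass

    import numpy as np

    # hypothetical dataclass; it structurally satisfies ObjectDetectionTarget
    @dataclass
    class MyObjectDetectionTarget:
        boxes: np.ndarray   # shape-(D_i, 4)
        labels: np.ndarray  # shape-(D_i,)
        scores: np.ndarray  # shape-(D_i,)

    # two detections in a single image
    target = MyObjectDetectionTarget(
        boxes=np.array([[1.0, 3.0, 5.0, 9.0], [2.0, 5.0, 8.0, 12.0]]),
        labels=np.array([0, 2]),
        scores=np.array([0.9, 0.5]),
    )

The object detection model example in the Models section below reuses
this same dataclass pattern.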
3 Models
--------

All models implement a ``__call__`` method that takes the
``InputBatchType`` and produces the ``TargetBatchType`` appropriate for
the given machine learning task.

.. code:: ipython3

    import maite.protocols.image_classification as ic
    print(ic.Model.__doc__)

.. parsed-literal::

    A model protocol for the image classification ML subproblem.

    Implementers must provide a `__call__` method that operates on a batch of
    model inputs (as `Sequence[ArrayLike]`) and returns a batch of model
    targets (as `Sequence[ArrayLike]`)

    Methods
    -------

    __call__(input_batch: Sequence[ArrayLike]) -> Sequence[ArrayLike]
        Make a model prediction for inputs in input batch. Input batch is
        expected to be `Sequence[ArrayLike]` with each element of shape
        `(C, H, W)`.

    Attributes
    ----------

    metadata : ModelMetadata
        A typed dictionary containing at least an 'id' field of type str

    Examples
    --------

    We create a multinomial logistic regression classifier for a
    CIFAR-10-like dataset with 10 classes and shape-(3, 32, 32) images.

    >>> import maite.protocols.image_classification as ic
    >>> import numpy as np
    >>> import numpy.typing as npt
    >>> from maite.protocols import ArrayLike, ModelMetadata
    >>> from typing import Sequence

    Creating a MAITE-compliant model involves writing a `__call__` method
    that takes a batch of inputs and returns a batch of predictions
    (probabilities).

    >>> class LinearClassifier:
    ...     def __init__(self) -> None:
    ...         # Set up required metadata attribute using the default `ModelMetadata` type,
    ...         # using class name for the ID
    ...         self.metadata: ModelMetadata = {"id": self.__class__.__name__}
    ...
    ...         # Initialize weights
    ...         rng = np.random.default_rng(12345678)
    ...         num_classes = 10
    ...         flattened_size = 3 * 32 * 32
    ...         self.weights = -0.2 + 0.4 * rng.random((flattened_size, num_classes))
    ...         self.bias = -0.2 + 0.4 * rng.random((1, num_classes))
    ...
    ...     def __call__(self, batch: Sequence[ArrayLike]) -> Sequence[npt.NDArray]:
    ...         # Convert each element in batch to ndarray, flatten,
    ...         # then combine into 2D array of shape-(N, C * H * W)
    ...         batch_np = np.vstack([np.asarray(x).flatten() for x in batch])
    ...
    ...         # Send input batch through model
    ...         out = batch_np @ self.weights + self.bias
    ...         out = np.exp(out) / np.sum(np.exp(out), axis=1, keepdims=True)  # softmax
    ...
    ...         # Restructure to sequence of shape-(10,) probabilities
    ...         return [row for row in out]

    We set up a test batch, instantiate the model, and apply it to the batch.

    >>> batch_size = 8
    >>> rng = np.random.default_rng(12345678)
    >>> batch: Sequence[ArrayLike] = [-0.2 + 0.4 * rng.random((3, 32, 32)) for _ in range(batch_size)]
    >>>
    >>> model: ic.Model = LinearClassifier()
    >>> out = model(batch)

    We can now show the class probabilities returned by the model for each
    image in the batch.

    >>> for probs in out:
    ...     print(np.round(probs, 2))
    [0.16 0.1  0.16 0.14 0.04 0.02 0.06 0.04 0.17 0.1 ]
    [0.21 0.16 0.04 0.07 0.08 0.05 0.09 0.03 0.18 0.09]
    [0.15 0.11 0.13 0.11 0.09 0.09 0.07 0.04 0.19 0.02]
    [0.04 0.08 0.14 0.07 0.12 0.2  0.11 0.06 0.14 0.04]
    [0.03 0.08 0.06 0.05 0.17 0.18 0.09 0.03 0.12 0.19]
    [0.09 0.04 0.1  0.03 0.32 0.05 0.07 0.04 0.15 0.09]
    [0.15 0.05 0.1  0.05 0.11 0.14 0.04 0.08 0.08 0.2 ]
    [0.11 0.11 0.08 0.11 0.08 0.05 0.24 0.03 0.08 0.12]

    Note that when writing a Model implementer, return types may be narrower
    than the return types promised by the protocol (npt.NDArray is a subtype
    of ArrayLike), but the argument types must be at least as general as the
    argument types promised by the protocol.

.. code:: ipython3

    import maite.protocols.object_detection as od
    print(od.Model.__doc__)

.. parsed-literal::

    A model protocol for the object detection ML subproblem.

    Implementers must provide a `__call__` method that operates on a batch of
    model inputs (as `Sequence[ArrayLike]`s) and returns a batch of model
    targets (as `Sequence[ObjectDetectionTarget]`)

    Methods
    -------

    __call__(input_batch: Sequence[ArrayLike]) -> Sequence[ObjectDetectionTarget]
        Make a model prediction for inputs in input batch. Elements of input
        batch are expected in the shape `(C, H, W)`.

    Attributes
    ----------

    metadata : ModelMetadata
        A typed dictionary containing at least an 'id' field of type str

    Examples
    --------

    We create a simple MAITE-compliant object detection model; note that it
    is a dummy model.
    >>> from dataclasses import dataclass
    >>> from typing import Sequence
    >>> import numpy as np
    >>> import maite.protocols.object_detection as od
    >>> from maite.protocols import ModelMetadata

    We define an object detection target dataclass as the output type of the
    object detection model.

    >>> @dataclass
    ... class MyObjectDetectionTarget:
    ...     boxes: np.ndarray
    ...     labels: np.ndarray
    ...     scores: np.ndarray
    ...

    Specify parameters that will be used to create a dummy dataset.

    >>> N_DATAPOINTS = 2  # datapoints in dataset
    >>> N_CLASSES = 5  # possible classes that can be detected
    >>> C = 3  # number of color channels
    >>> H = 32  # img height
    >>> W = 32  # img width

    Now create a batch of data to form the inputs of the MAITE object
    detection model.

    >>> simple_batch: list[np.ndarray] = [
    ...     np.random.rand(C, H, W) for _ in range(N_DATAPOINTS)
    ... ]

    We define a simple dummy object detection model; its `__call__` method
    ignores the input data and simply returns hard-coded
    `MyObjectDetectionTarget` instances.

    >>> class ObjectDetectionDummyModel:
    ...     metadata: ModelMetadata = {"id": "ObjectDetectionDummyModel"}
    ...
    ...     def __call__(
    ...         self, batch: od.InputBatchType
    ...     ) -> Sequence[MyObjectDetectionTarget]:
    ...         # For simplicity, we return fixed detections rather than
    ...         # computing them from the input batch.
    ...         DETECTIONS_PER_IMG = (
    ...             2  # number of bounding box detections per image/datapoint
    ...         )
    ...         all_boxes = np.array(
    ...             [[1, 3, 5, 9], [2, 5, 8, 12], [4, 10, 8, 20], [3, 5, 6, 15]]
    ...         )  # all detection boxes for N_DATAPOINTS
    ...         all_predictions = list()
    ...         for datum_idx in range(N_DATAPOINTS):
    ...             boxes = all_boxes[datum_idx : datum_idx + DETECTIONS_PER_IMG]
    ...             labels = np.random.randint(N_CLASSES, size=DETECTIONS_PER_IMG)
    ...             scores = np.random.rand(DETECTIONS_PER_IMG)
    ...             predictions = MyObjectDetectionTarget(boxes, labels, scores)
    ...             all_predictions.append(predictions)
    ...         return all_predictions
    ...

    We can instantiate this class and typehint it as a MAITE object detection
    model. By using typehinting, we permit a static typechecker to verify
    protocol compliance.

    >>> od_dummy_model: od.Model = ObjectDetectionDummyModel()
    >>> od_dummy_model.metadata
    {'id': 'ObjectDetectionDummyModel'}
    >>> predictions = od_dummy_model(simple_batch)
    >>> predictions  # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
    [MyObjectDetectionTarget(boxes=array([[ 1,  3,  5,  9], [ 2,  5,  8, 12]]),
    labels=array([..., ...]), scores=array([..., ...])),
    MyObjectDetectionTarget(boxes=array([[ 2,  5,  8, 12], [ 4, 10,  8, 20]]),
    labels=array([..., ...]), scores=array([..., ...]))]

4 Datasets and DataLoaders
--------------------------

``Dataset``\ s provide access to single data items and ``DataLoader``\ s
provide access to batches of data with the input, target, and metadata
types corresponding to the given machine learning task.

.. code:: ipython3

    print(ic.Dataset.__doc__)

.. parsed-literal::

    A dataset protocol for image classification ML subproblem providing
    datum-level data access.

    Implementers must provide index lookup (via `__getitem__(ind: int)`
    method) and support `len` (via `__len__()` method). Data elements looked
    up this way correspond to individual examples (as opposed to batches).

    Indexing into or iterating over an image_classification dataset returns a
    `tuple` of types `ArrayLike`, `ArrayLike`, and `DatumMetadataType`.
    These correspond to the model input type, model target type, and
    datum-level metadata, respectively.
    Methods
    -------

    __getitem__(ind: int) -> tuple[ArrayLike, ArrayLike, DatumMetadataType]
        Provide map-style access to dataset elements. Returned tuple elements
        correspond to model input type, model target type, and datum-specific
        metadata type, respectively.

    __len__() -> int
        Return the number of data elements in the dataset.

    Attributes
    ----------

    metadata : DatasetMetadata
        A typed dictionary containing at least an 'id' field of type str

    Examples
    --------

    We create a dummy set of data and use it to create a class that
    implements this lightweight dataset protocol:

    >>> import numpy as np
    >>> from typing import Any, TypedDict
    >>> from maite.protocols import ArrayLike, DatasetMetadata, DatumMetadata

    Assume we have 5 classes, 10 datapoints, and 10 target labels, and that
    we want to simply have an integer 'id' field in each datapoint's
    metadata:

    >>> N_CLASSES: int = 5
    >>> N_DATUM: int = 10
    >>> images: list[np.ndarray] = [np.random.rand(3, 32, 16) for _ in range(N_DATUM)]
    >>> targets: np.ndarray = np.eye(N_CLASSES)[np.random.choice(N_CLASSES, N_DATUM)]

    We can type our datum metadata as a maite.protocols DatumMetadata, or
    define our own TypedDict with additional fields:

    >>> class MyDatumMetadata(DatumMetadata):
    ...     hour_of_day: float
    ...
    >>> datum_metadata = [ MyDatumMetadata(id = i, hour_of_day=np.random.rand()*24) for i in range(N_DATUM) ]

    Constructing a compliant dataset just involves a simple wrapper that
    fetches individual datapoints, where a datapoint is a single image,
    target, metadata 3-tuple.

    >>> class ImageDataset:
    ...     def __init__(self,
    ...                  dataset_name: str,
    ...                  index2label: dict[int, str],
    ...                  images: list[np.ndarray],
    ...                  targets: np.ndarray,
    ...                  datum_metadata: list[MyDatumMetadata]):
    ...         self.images = images
    ...         self.targets = targets
    ...         self.metadata = DatasetMetadata({'id': dataset_name, 'index2label': index2label})
    ...         self._datum_metadata = datum_metadata
    ...     def __len__(self) -> int:
    ...         return len(self.images)
    ...     def __getitem__(self, ind: int) -> tuple[np.ndarray, np.ndarray, MyDatumMetadata]:
    ...         return self.images[ind], self.targets[ind], self._datum_metadata[ind]

    We can instantiate this class and typehint it as an
    image_classification.Dataset. By using typehinting, we permit a static
    typechecker to verify protocol compliance.

    >>> from maite.protocols import image_classification as ic
    >>> dataset: ic.Dataset = ImageDataset('a_dataset',
    ...                                    {i: f"class_name_{i}" for i in range(N_CLASSES)},
    ...                                    images, targets, datum_metadata)

    Note that when writing a Dataset implementer, return types may be
    narrower than the return types promised by the protocol (np.ndarray is a
    subtype of ArrayLike), but the argument types must be at least as general
    as the argument types promised by the protocol.

.. code:: ipython3

    print(ic.DataLoader.__doc__)

.. parsed-literal::

    A dataloader protocol for the image classification ML subproblem
    providing batch-level data access.

    Implementers must provide an iterable object (returning an iterator via
    the `__iter__` method) that yields tuples containing batches of data.
    These tuples contain types `Sequence[ArrayLike]` (elements of shape
    `(C, H, W)`), `Sequence[ArrayLike]` (elements of shape `(Cl, )`), and
    `Sequence[DatumMetadataType]`, which correspond to model input batch,
    model target batch, and a datum metadata batch.

    Note: Unlike Dataset, this protocol does not require indexing support,
    only iterating.

    Methods
    -------

    __iter__ -> Iterator[tuple[Sequence[ArrayLike], Sequence[ArrayLike], Sequence[DatumMetadataType]]]
        Return an iterator over batches of data, where each batch contains a
        tuple of model input batch (as `Sequence[ArrayLike]`), model target
        batch (as `Sequence[ArrayLike]`), and batched datum-level metadata
        (as `Sequence[DatumMetadataType]`), respectively.
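The docstring above has no example, but a DataLoader implementer can be
a thin batching wrapper around a ``Dataset``. Below is a minimal sketch
(the class name and batching logic are our own, not part of MAITE),
assuming the ``dataset`` built in the Dataset example above:

.. code:: python

    from maite.protocols import image_classification as ic

    class SimpleDataLoader:
        # Hypothetical batching wrapper; yields (input batch, target batch,
        # metadata batch) tuples, as the DataLoader protocol requires.
        def __init__(self, dataset: ic.Dataset, batch_size: int = 4):
            self.dataset = dataset
            self.batch_size = batch_size
            self.metadata = {"id": "simple_dataloader"}  # mirror the 'id' metadata used by other components

        def __iter__(self):
            xs, ys, mds = [], [], []
            for i in range(len(self.dataset)):
                x, y, md = self.dataset[i]
                xs.append(x)
                ys.append(y)
                mds.append(md)
                if len(xs) == self.batch_size:
                    yield xs, ys, mds
                    xs, ys, mds = [], [], []
            if xs:  # final partial batch
                yield xs, ys, mds

    # typehint as ic.DataLoader so a static typechecker can check conformance
    dataloader: ic.DataLoader = SimpleDataLoader(dataset)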
.. code:: ipython3

    print(od.DataLoader.__doc__)

.. parsed-literal::

    A dataloader protocol for the object detection ML subproblem providing
    batch-level data access.

    Implementers must provide an iterable object (returning an iterator via
    the `__iter__` method) that yields tuples containing batches of data.
    These tuples contain types `Sequence[ArrayLike]` (elements of shape
    `(C, H, W)`), `Sequence[ObjectDetectionTarget]`, and
    `Sequence[DatumMetadataType]`, which correspond to model input batch,
    model target batch, and a datum metadata batch.

    Note: Unlike Dataset, this protocol does not require indexing support,
    only iterating.

    Methods
    -------

    __iter__ -> Iterator[tuple[Sequence[ArrayLike], Sequence[ObjectDetectionTarget], Sequence[DatumMetadataType]]]
        Return an iterator over batches of data, where each batch contains a
        tuple of model input batch (as `Sequence[ArrayLike]`), model target
        batch (as `Sequence[ObjectDetectionTarget]`), and batched datum-level
        metadata (as `Sequence[DatumMetadataType]`), respectively.

5 Augmentations
---------------

``Augmentation``\ s take in and return a batch of data with the
``InputBatchType``, ``TargetBatchType``, and ``DatumMetadataBatchType``
types corresponding to the given machine learning task.

Augmentations can access the datum-level metadata associated with each
data item to potentially tailor the augmentation to individual items.
Augmentations can also associate new datum-level metadata with each
data item, e.g., documenting aspects of the actual change that was
applied (e.g., the actual rotation angle sampled from a range of
possible angles).

.. code:: ipython3

    print(ic.Augmentation.__doc__)

.. parsed-literal::

    An augmentation protocol for the image classification subproblem.

    An augmentation is expected to take a batch of data and return a modified
    version of that batch. Implementers must provide a single method that
    takes and returns a labeled data batch, where a labeled data batch is
    represented by a tuple of types `Sequence[ArrayLike]` (with elements of
    shape `(C, H, W)`), `Sequence[ArrayLike]` (with elements of shape
    `(Cl, )`), and `Sequence[DatumMetadataType]`. These correspond to the
    model input batch type, model target batch type, and datum-level metadata
    batch type, respectively.

    Methods
    -------

    __call__(datum: tuple[Sequence[ArrayLike], Sequence[ArrayLike], Sequence[DatumMetadataType]]) -> tuple[Sequence[ArrayLike], Sequence[ArrayLike], Sequence[DatumMetadataType]]
        Return a modified version of original data batch. A data batch is
        represented by a tuple of model input batch (as `Sequence[ArrayLike]`
        with elements of shape `(C, H, W)`), model target batch (as
        `Sequence[ArrayLike]` with elements of shape `(Cl,)`), and batch
        metadata (as `Sequence[DatumMetadataType]`), respectively.

    Attributes
    ----------

    metadata : AugmentationMetadata
        A typed dictionary containing at least an 'id' field of type str

    Examples
    --------

    We can write an implementer of the augmentation class as either a
    function or a class.
    The only requirement is that the object provide a __call__ method that
    takes objects at least as general as the types promised in the protocol
    signature and returns objects at least as specific.

    >>> import copy
    >>> import numpy as np
    >>> from typing import Any
    >>> from collections.abc import Sequence
    >>> from maite.protocols import ArrayLike, DatumMetadata, AugmentationMetadata
    >>>
    >>> class EnrichedDatumMetadata(DatumMetadata):
    ...     new_key: int  # add a field to those already in DatumMetadata
    ...
    >>> class ImageAugmentation:
    ...     def __init__(self, aug_name: str):
    ...         self.metadata: AugmentationMetadata = {'id': aug_name}
    ...     def __call__(
    ...         self,
    ...         data_batch: tuple[Sequence[ArrayLike], Sequence[ArrayLike], Sequence[DatumMetadata]]
    ...     ) -> tuple[Sequence[np.ndarray], Sequence[np.ndarray], Sequence[EnrichedDatumMetadata]]:
    ...         inputs, targets, mds = data_batch
    ...         # We copy data passed into __call__ to avoid mutating the original inputs
    ...         # By using the np.ndarray constructor, the static type-checker will let us treat
    ...         # generic ArrayLike as a more narrow return type
    ...         inputs_aug = [copy.copy(np.array(input)) for input in inputs]
    ...         targets_aug = [copy.copy(np.array(target)) for target in targets]
    ...         # Modify inputs_aug, targets_aug, or mds_aug as needed
    ...         # In this example, we just add a new metadata field
    ...         mds_aug = []
    ...         for i, md in enumerate(mds):
    ...             mds_aug.append(EnrichedDatumMetadata(**md, new_key=i))
    ...         return inputs_aug, targets_aug, mds_aug

    We can typehint an instance of the above class as an Augmentation in the
    image_classification domain:

    >>> from maite.protocols import image_classification as ic
    >>> im_aug: ic.Augmentation = ImageAugmentation(aug_name = 'an_augmentation')

.. code:: ipython3

    print(od.Augmentation.__doc__)

.. parsed-literal::

    An augmentation protocol for the object detection subproblem.

    An augmentation is expected to take a batch of data and return a modified
    version of that batch. Implementers must provide a single method that
    takes and returns a labeled data batch, where a labeled data batch is
    represented by a tuple of types `Sequence[ArrayLike]`,
    `Sequence[ObjectDetectionTarget]`, and `Sequence[DatumMetadataType]`.
    These correspond to the model input batch type, model target batch type,
    and datum-level metadata batch type, respectively.

    Methods
    -------

    __call__(datum: tuple[Sequence[ArrayLike], Sequence[ObjectDetectionTarget], Sequence[DatumMetadataType]]) -> tuple[Sequence[ArrayLike], Sequence[ObjectDetectionTarget], Sequence[DatumMetadataType]]
        Return a modified version of original data batch. A data batch is
        represented by a tuple of model input batch (as `Sequence[ArrayLike]`
        with elements of shape `(C, H, W)`), model target batch (as
        `Sequence[ObjectDetectionTarget]`), and batch metadata (as
        `Sequence[DatumMetadataType]`), respectively.

    Attributes
    ----------

    metadata : AugmentationMetadata
        A typed dictionary containing at least an 'id' field of type str

    Examples
    --------

    We can write an implementer of the augmentation class as either a
    function or a class. The only requirement is that the object provide a
    `__call__` method that takes objects at least as general as the types
    promised in the protocol signature and returns objects at least as
    specific.
    We create a dummy set of data and use it to create a class that
    implements this lightweight protocol and augments the data:

    >>> import numpy as np
    >>> np.random.seed(1)
    >>> import copy
    >>> from dataclasses import dataclass
    >>> from typing import Any
    >>> from maite.protocols import AugmentationMetadata, object_detection as od

    First, we specify parameters that will be used to create the dummy
    dataset.

    >>> N_DATAPOINTS = 3  # datapoints in dataset
    >>> N_CLASSES = 2  # possible classes that can be detected
    >>> C = 3  # number of color channels
    >>> H = 10  # img height
    >>> W = 10  # img width

    Next, we create the input data to be used by the Augmentation. In this
    example, we create the following batch of object detection data:

    - `xb` is the input batch data. Our batch will include `N_DATAPOINTS`
      number of samples. Note that we initialize all of the data to zeros in
      this example to demonstrate the augmentations better.
    - `yb` is the object detection target data, which in this example
      represents zero object detections for each input datum (by having empty
      bounding boxes and class labels and scores).
    - `mdb` is the associated metadata for each input datum.

    >>> @dataclass
    ... class MyObjectDetectionTarget:
    ...     boxes: np.ndarray
    ...     labels: np.ndarray
    ...     scores: np.ndarray
    >>> xb: od.InputBatchType = list(np.zeros((N_DATAPOINTS, C, H, W)))
    >>> yb: od.TargetBatchType = list(
    ...     MyObjectDetectionTarget(boxes=np.empty(0), labels=np.empty(0), scores=np.empty(0))
    ...     for _ in range(N_DATAPOINTS)
    ... )
    >>> mdb: od.DatumMetadataBatchType = list({"id": i} for i in range(N_DATAPOINTS))
    >>> # Display the first datum in batch, first color channel, and only first 5 rows and cols
    >>> np.array(xb[0])[0][:5, :5]
    array([[0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.]])

    Now we create the Augmentation, which will apply random noise (rounded to
    3 decimal places) to the input data using the `numpy.random.random` and
    `np.round` functions.

    >>> np_noise = lambda shape: np.round(np.random.random(shape), 3)
    >>> class ImageAugmentation:
    ...     def __init__(self, aug_func: Any, metadata: AugmentationMetadata):
    ...         self.aug_func = aug_func
    ...         self.metadata = metadata
    ...     def __call__(
    ...         self,
    ...         batch: tuple[od.InputBatchType, od.TargetBatchType, od.DatumMetadataBatchType],
    ...     ) -> tuple[od.InputBatchType, od.TargetBatchType, od.DatumMetadataBatchType]:
    ...         xb, yb, mdb = batch
    ...         # Add random noise to the input batch data, xb; the addition
    ...         # creates new arrays, so the original inputs are not mutated
    ...         # (Note that all batch data dimensions (shapes) are the same in this example)
    ...         shape = np.array(xb[0]).shape
    ...         xb_aug = [x + self.aug_func(shape) for x in xb]
    ...         # Note that this example augmentation only affects inputs--not targets
    ...         return xb_aug, yb, mdb

    We can typehint an instance of the above class as an Augmentation in the
    object detection domain:

    >>> noise: od.Augmentation = ImageAugmentation(np_noise, metadata={"id": "np_rand_noise"})

    Now we can apply the `noise` augmentation to our 3-tuple batch of data.
    Recall that our data was initialized to all zeros, so any non-zero values
    in the augmented data are a result of the augmentation.

    >>> xb_aug, yb_aug, mdb_aug = noise((xb, yb, mdb))
    >>> # Display the first datum in batch, first color channel, and only first 5 rows and cols
    >>> np.array(xb_aug[0])[0][:5, :5]
    array([[0.417, 0.72 , 0.   , 0.302, 0.147],
           [0.419, 0.685, 0.204, 0.878, 0.027],
           [0.801, 0.968, 0.313, 0.692, 0.876],
           [0.098, 0.421, 0.958, 0.533, 0.692],
           [0.989, 0.748, 0.28 , 0.789, 0.103]])
6 Metrics
---------

The ``Metric`` protocol is inspired by the design of existing libraries
like Torchmetrics and Torcheval. The ``update`` method operates on
batches of predictions and truth labels by either caching them for
later computation of the metric (via ``compute``) or updating
sufficient statistics in an online fashion.

.. code:: ipython3

    print(ic.Metric.__doc__)

.. parsed-literal::

    A metric protocol for the image classification ML subproblem.

    A metric in this sense is expected to measure the level of agreement
    between model predictions and ground-truth labels.

    Methods
    -------

    update(preds: Sequence[ArrayLike], targets: Sequence[ArrayLike]) -> None
        Add predictions and targets to metric's cache for later calculation.
        Both preds and targets are expected to be sequences with elements of
        shape `(Cl,)`.

    compute() -> dict[str, Any]
        Compute metric value(s) for currently cached predictions and targets,
        returned as a dictionary.

    reset() -> None
        Clear contents of current metric's cache of predictions and targets.

    Attributes
    ----------

    metadata : MetricMetadata
        A typed dictionary containing at least an 'id' field of type str

    Examples
    --------

    Create a basic accuracy metric and test it on a small example dataset:

    >>> from typing import Any, Sequence
    >>> import numpy as np
    >>> from maite.protocols import ArrayLike, MetricMetadata
    >>> from maite.protocols import image_classification as ic
    >>> class MyAccuracy:
    ...     metadata: MetricMetadata = {'id': 'Example Multiclass Accuracy'}
    ...
    ...     def __init__(self):
    ...         self._total = 0
    ...         self._correct = 0
    ...
    ...     def reset(self) -> None:
    ...         self._total = 0
    ...         self._correct = 0
    ...
    ...     def update(self, preds: Sequence[ArrayLike], targets: Sequence[ArrayLike]) -> None:
    ...         model_probs = [np.array(r) for r in preds]
    ...         true_onehot = [np.array(r) for r in targets]
    ...
    ...         # Stack into single array, convert to class indices
    ...         model_classes = np.vstack(model_probs).argmax(axis=1)
    ...         truth_classes = np.vstack(true_onehot).argmax(axis=1)
    ...
    ...         # Compare classes and update running counts
    ...         same = (model_classes == truth_classes)
    ...         self._total += len(same)
    ...         self._correct += same.sum()
    ...
    ...     def compute(self) -> dict[str, Any]:
    ...         if self._total > 0:
    ...             return {"accuracy": self._correct / self._total}
    ...         else:
    ...             raise Exception("No batches processed yet.")

    We can instantiate this class and typehint it as an
    image_classification.Metric. By using typehinting, we permit a static
    typechecker to verify protocol compliance.

    >>> accuracy: ic.Metric = MyAccuracy()

    To use the metric, call update() for each batch of predictions and truth
    values and call compute() to calculate the final metric values.

    >>> # batch 1
    >>> model_probs = [np.array([0.8, 0.1, 0.0, 0.1]), np.array([0.1, 0.2, 0.6, 0.1])]  # predicted classes: 0, 2
    >>> true_onehot = [np.array([1.0, 0.0, 0.0, 0.0]), np.array([0.0, 1.0, 0.0, 0.0])]  # true classes: 0, 1
    >>> accuracy.update(model_probs, true_onehot)
    >>> print(accuracy.compute())
    {'accuracy': 0.5}
    >>>
    >>> # batch 2
    >>> model_probs = [np.array([0.1, 0.1, 0.7, 0.1]), np.array([0.0, 0.1, 0.0, 0.9])]  # predicted classes: 2, 3
    >>> true_onehot = [np.array([0.0, 0.0, 1.0, 0.0]), np.array([0.0, 0.0, 0.0, 1.0])]  # true classes: 2, 3
    >>> accuracy.update(model_probs, true_onehot)
    >>>
    >>> print(accuracy.compute())
    {'accuracy': 0.75}

.. code:: ipython3

    print(od.Metric.__doc__)

.. parsed-literal::
    A metric protocol for the object detection ML subproblem.

    A metric in this sense is expected to measure the level of agreement
    between model predictions and ground-truth labels.

    Methods
    -------

    update(preds: Sequence[ObjectDetectionTarget], targets: Sequence[ObjectDetectionTarget]) -> None
        Add predictions and targets to metric's cache for later calculation.

    compute() -> dict[str, Any]
        Compute metric value(s) for currently cached predictions and targets,
        returned as a dictionary.

    reset() -> None
        Clear contents of current metric's cache of predictions and targets.

    Attributes
    ----------

    metadata : MetricMetadata
        A typed dictionary containing at least an 'id' field of type str

    Examples
    --------

    Below, we write and test a class that implements the Metric protocol for
    object detection. For simplicity, the metric will compute the
    intersection over union (IoU) averaged over all predicted and associated
    ground truth boxes for a single image, and then take the mean over the
    per-image means.

    Note that when writing a `Metric` implementer, return types may be
    narrower than the return types promised by the protocol, but the argument
    types must be at least as general as the argument types promised by the
    protocol.

    >>> from dataclasses import dataclass
    >>> from maite.protocols import ArrayLike, MetricMetadata
    >>> from typing import Any, Sequence
    >>> import maite.protocols.object_detection as od
    >>> import numpy as np
    >>> class MyIoUMetric:
    ...
    ...     def __init__(self, id: str):
    ...         self.pred_boxes = []  # elements correspond to predicted boxes in single image
    ...         self.target_boxes = []  # elements correspond to ground truth boxes in single image
    ...         # Store provided id for this metric instance
    ...         self.metadata = MetricMetadata(
    ...             id=id
    ...         )
    ...
    ...     def reset(self) -> None:
    ...         self.pred_boxes = []
    ...         self.target_boxes = []
    ...
    ...     def update(
    ...         self,
    ...         pred_batch: Sequence[od.ObjectDetectionTarget],
    ...         target_batch: Sequence[od.ObjectDetectionTarget],
    ...     ) -> None:
    ...         self.pred_boxes.extend(pred_batch)
    ...         self.target_boxes.extend(target_batch)
    ...
    ...     @staticmethod
    ...     def iou_vec(boxes_a: ArrayLike, boxes_b: ArrayLike) -> np.ndarray:
    ...         # Break up points into separate columns
    ...         x0a, y0a, x1a, y1a = np.split(boxes_a, 4, axis=1)
    ...         x0b, y0b, x1b, y1b = np.split(boxes_b, 4, axis=1)
    ...         # Calculate intersections (clamped at zero for non-overlapping boxes)
    ...         xi_0, yi_0 = np.split(
    ...             np.maximum(np.append(x0a, y0a, axis=1), np.append(x0b, y0b, axis=1)),
    ...             2,
    ...             axis=1,
    ...         )
    ...         xi_1, yi_1 = np.split(
    ...             np.minimum(np.append(x1a, y1a, axis=1), np.append(x1b, y1b, axis=1)),
    ...             2,
    ...             axis=1,
    ...         )
    ...         ints: np.ndarray = np.maximum(0, xi_1 - xi_0) * np.maximum(0, yi_1 - yi_0)
    ...         # Calculate unions (as sum of areas minus their intersection)
    ...         unions: np.ndarray = (
    ...             (x1a - x0a) * (y1a - y0a)
    ...             + (x1b - x0b) * (y1b - y0b)
    ...             - ints
    ...         )
    ...         return ints / unions
    ...
    ...     def compute(self) -> dict[str, Any]:
    ...         mean_iou_by_img: list[float] = []
    ...         for pred_box, tgt_box in zip(self.pred_boxes, self.target_boxes):
    ...             single_img_ious = self.iou_vec(pred_box.boxes, tgt_box.boxes)
    ...             mean_iou_by_img.append(float(np.mean(single_img_ious)))
    ...         return {"mean_iou": np.mean(np.array(mean_iou_by_img))}
    ...

    Now we can instantiate our IoU Metric class:

    >>> iou_metric: od.Metric = MyIoUMetric(id="IoUMetric")

    To use the metric, we populate two lists that encode predicted object
    detections and ground truth object detections for a single image.
    (Ordinarily, predictions would be produced by a model.)
    >>> prediction_boxes: list[tuple[int, int, int, int]] = [
    ...     (1, 1, 12, 12),
    ...     (100, 100, 120, 120),
    ...     (180, 180, 270, 270),
    ... ]
    >>> target_boxes: list[tuple[int, int, int, int]] = [
    ...     (1, 1, 10, 10),
    ...     (100, 100, 120, 120),
    ...     (200, 200, 300, 300),
    ... ]

    The MAITE Metric protocol requires the `pred_batch` and `target_batch`
    arguments to the `update` method to be assignable to
    Sequence[ObjectDetectionTarget] (where ObjectDetectionTarget encodes
    detections in a single image). We define an implementation of
    ObjectDetectionTarget and use it to pass ground truth and predicted
    detections.

    >>> @dataclass
    ... class ObjectDetectionTargetImpl:
    ...     boxes: np.ndarray
    ...     labels: np.ndarray
    ...     scores: np.ndarray
    >>> num_boxes = len(target_boxes)
    >>> fake_labels = np.random.randint(0, 9, num_boxes)
    >>> fake_scores = np.zeros(num_boxes)
    >>> pred_batch = [
    ...     ObjectDetectionTargetImpl(
    ...         boxes=np.array(prediction_boxes), labels=fake_labels, scores=fake_scores
    ...     )
    ... ]
    >>> target_batch: Sequence[ObjectDetectionTargetImpl] = [
    ...     ObjectDetectionTargetImpl(
    ...         boxes=np.array(target_boxes), labels=fake_labels, scores=fake_scores
    ...     )
    ... ]

    Finally, we call `update` using this one-element batch, compute the
    metric value, and print it:

    >>> iou_metric.update(pred_batch, target_batch)
    >>> print(iou_metric.compute())
    {'mean_iou': 0.6802112029384757}

7 Workflows
-----------

MAITE provides high-level utilities for common workflows such as
``evaluate`` and ``predict``. They can be called with either
``Dataset``\ s or ``DataLoader``\ s, and with an optional
``Augmentation``.

The ``evaluate`` function can optionally return the model predictions
and (potentially-augmented) data batches used during inference.

The ``predict`` function returns the model predictions and
(potentially-augmented) data batches used during inference, essentially
calling ``evaluate`` with a dummy metric.

.. code:: ipython3

    from maite.workflows import evaluate, predict

    # we can also import from the object_detection module,
    # where the function call signature is the same

.. code:: ipython3

    print(evaluate.__doc__)

.. parsed-literal::

    Evaluate a model's performance on data according to some metric with
    optional augmentation.

    Some data source (either a dataloader or a dataset) must be provided or
    an InvalidArgument exception is raised.

    Parameters
    ----------
    model : SomeModel
        Maite Model object.

    metric : SomeMetric | None, (default=None)
        Compatible maite Metric.

    dataloader : SomeDataloader | None, (default=None)
        Compatible maite dataloader.

    dataset : SomeDataset | None, (default=None)
        Compatible maite dataset.

    batch_size : int, (default=1)
        Batch size for use with dataset (ignored if dataset=None).

    augmentation : SomeAugmentation | None, (default=None)
        Compatible maite augmentation.

    return_augmented_data : bool, (default=False)
        Set to True to return post-augmentation data as a function output.

    return_preds : bool, (default=False)
        Set to True to return raw predictions as a function output.

    Returns
    -------
    tuple[dict[str, Any], Sequence[TargetType], Sequence[tuple[InputBatchType, TargetBatchType, DatumMetadataBatchType]]]
        Tuple of returned metric value, sequence of model predictions, and
        sequence of data batch tuples fed to the model during inference. The
        actual types represented by InputBatchType, TargetBatchType, and
        DatumMetadataBatchType will vary by the domain of the components
        provided as input arguments (e.g., image classification or object
        detection).

    Note that the second and third return arguments will be empty if
    return_preds is False or return_augmented_data is False, respectively.
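As a usage sketch (assuming MAITE-compliant ``model``, ``metric``, and
``dataset`` objects like those constructed in the sections above, with
mutually compatible shapes):

.. code:: python

    from maite.workflows import evaluate

    # returns the metric results plus (optionally) the predictions and
    # the batches that were fed to the model
    results, preds, batches = evaluate(
        model=model,
        metric=metric,
        dataset=dataset,
        batch_size=4,
        return_preds=True,
    )
    print(results)  # a dict of metric values, e.g., {'accuracy': ...}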
.. code:: ipython3

    print(predict.__doc__)

.. parsed-literal::

    Make predictions for a given model & data source with optional
    augmentation.

    Some data source (either a dataloader or a dataset) must be provided or
    an InvalidArgument exception is raised.

    Parameters
    ----------
    model : SomeModel
        Maite Model object.

    dataloader : SomeDataloader | None, (default=None)
        Compatible maite dataloader.

    dataset : SomeDataset | None, (default=None)
        Compatible maite dataset.

    batch_size : int, (default=1)
        Batch size for use with dataset (ignored if dataset=None).

    augmentation : SomeAugmentation | None, (default=None)
        Compatible maite augmentation.

    Returns
    -------
    tuple[Sequence[SomeTargetBatchType], Sequence[tuple[SomeInputBatchType, SomeTargetBatchType, SomeMetadataBatchType]]]
        A tuple of the predictions (as a sequence of batches) and a sequence
        of tuples containing the information associated with each batch.
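And correspondingly for ``predict`` (again a sketch, assuming
compatible ``model`` and ``dataset`` objects):

.. code:: python

    from maite.workflows import predict

    # predictions come back as a sequence of target batches; `batches`
    # carries the (input, target, metadata) tuples used during inference
    preds, batches = predict(
        model=model,
        dataset=dataset,
        batch_size=4,
    )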