Copyright 2024, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
SPDX-License-Identifier: MIT

Overview of MAITE Protocols

MAITE provides protocols for the following AI components:

  • models

  • datasets

  • dataloaders

  • augmentations

  • metrics

MAITE protocols specify the expected interfaces of these components (i.e., a minimal set of required attributes, methods, and method type signatures) to promote interoperability in test and evaluation (T&E). This enables the creation of higher-level workflows (e.g., an evaluate utility) that can interact with any components that conform to the protocols.
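As a rough illustration of what conforming to a protocol means, the sketch below uses Python's standard typing.Protocol. The names SupportsPredict, DoublingModel, and evaluate are hypothetical and are not part of MAITE. The sketch only shows that a class can satisfy a protocol structurally, by having the right methods, without inheriting from it.

```python
from typing import Protocol, Sequence, runtime_checkable


@runtime_checkable
class SupportsPredict(Protocol):
    """Hypothetical protocol: any object with a matching __call__ conforms."""

    def __call__(self, batch: Sequence[float]) -> Sequence[float]: ...


class DoublingModel:
    """Does not inherit from SupportsPredict, yet structurally satisfies it."""

    def __call__(self, batch: Sequence[float]) -> Sequence[float]:
        return [2.0 * x for x in batch]


def evaluate(model: SupportsPredict, batch: Sequence[float]) -> Sequence[float]:
    # A higher-level utility that relies only on the protocol's interface
    return model(batch)


model = DoublingModel()
outputs = evaluate(model, [1.0, 2.5])

# The runtime check only verifies that the method exists, not its signature;
# a static type checker is what verifies the full signature.
conforms = isinstance(model, SupportsPredict)
```

Because conformance is structural, an existing model class can usually be used with such a utility by adding a thin wrapper rather than changing its inheritance hierarchy.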

1 Concept: Bridging ArrayLikes

MAITE uses a type called ArrayLike (following NumPy’s interoperability approach) that helps components that natively use different flavors of tensors (e.g., NumPy ndarray, PyTorch Tensor, JAX ndarray) work together.

In this example, the functions “type narrow” from ArrayLike to the type they want to work with internally. Depending on the actual input type, this narrowing does not necessarily require a data conversion.

import numpy as np
import torch

from maite.protocols import ArrayLike

def my_numpy_fn(x: ArrayLike) -> np.ndarray:
    arr = np.asarray(x)
    # ...
    return arr

def my_torch_fn(x: ArrayLike) -> torch.Tensor:
    tensor = torch.as_tensor(x)
    # ...
    return tensor

# can apply NumPy function to PyTorch Tensor
np_out = my_numpy_fn(torch.rand(2, 3))

# can apply PyTorch function to NumPy array
torch_out = my_torch_fn(np.random.rand(2, 3))

# note: no performance hit from conversion when all `ArrayLike`s are from the same library
# or when they can share the same underlying memory
torch_out = my_torch_fn(torch.rand(2, 3))

By using bridging, MAITE can permit implementers of the protocol to internally interact with their own types while exposing a more open interface to other MAITE-compliant components.
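The cost of narrowing depends on the input. The NumPy-only sketch below (variable names are illustrative) shows one case where np.asarray hands back the very same object, and contrasts it with np.array, which copies by default. It is meant only to make the "no conversion needed" case concrete, not to describe every array library.

```python
import numpy as np

original = np.arange(6, dtype=np.float64).reshape(2, 3)

# np.asarray returns the same object when the input is already an ndarray
same_object = np.asarray(original) is original

# Nested Python lists are also array-like, but must be copied into a new array
from_list = np.asarray([[0.0, 1.0], [2.0, 3.0]])

# np.array copies by default, so writes to the copy leave the original untouched
copied = np.array(original)
copied[0, 0] = -1.0
shares = np.shares_memory(original, copied)
```

A function that must not mutate its caller's data should therefore copy explicitly (np.array) rather than rely on np.asarray.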

2 Data Types

MAITE represents an individual data item as a tuple of:

  • input (e.g., image),

  • target (e.g., label), and

  • metadata (at the datum level)

and a batch of data items as a tuple of:

  • input batches,

  • target batches, and

  • metadata batches.

MAITE provides versions of the Model, Dataset, DataLoader, Augmentation, and Metric protocols that correspond to different machine learning tasks (e.g., image classification, object detection) by parameterizing protocol interfaces on the particular input, target, and metadata types associated with that task.
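The relationship between individual data items and a batch can be sketched with plain Python. In the example below, the collate helper and the sample values are hypothetical, and plain lists stand in for ArrayLike inputs and targets. It regroups a sequence of (input, target, metadata) tuples into one tuple of three batch sequences.

```python
from typing import Any, Sequence

# Three hypothetical data items, each an (input, target, metadata) tuple
items: list[tuple[list[float], list[int], dict[str, Any]]] = [
    ([0.1, 0.2], [1, 0], {"id": 0}),
    ([0.3, 0.4], [0, 1], {"id": 1}),
    ([0.5, 0.6], [1, 0], {"id": 2}),
]


def collate(
    data_items: Sequence[tuple[Any, Any, dict[str, Any]]],
) -> tuple[list[Any], list[Any], list[dict[str, Any]]]:
    """Turn a sequence of item tuples into a tuple of three batch sequences."""
    inputs, targets, metadata = zip(*data_items)
    return list(inputs), list(targets), list(metadata)


input_batch, target_batch, metadata_batch = collate(items)
```

Keeping each batch as a sequence of per-item arrays, rather than one stacked array, means items of different sizes can share a batch.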

2.1 Image Classification

For image classification with Cl image classes, we have:

from typing import Sequence, TypeAlias, TypedDict

from maite.protocols import ArrayLike

# define type to store an id of each datum (additional fields can be added by defining structurally-assignable TypedDict)
class DatumMetadataType(TypedDict):
    id: str | int

InputType: TypeAlias = ArrayLike  # shape-(C, H, W) tensor with single image
TargetType: TypeAlias = ArrayLike  # shape-(Cl,) tensor of one-hot encoded true class or predicted probabilities

InputBatchType: TypeAlias = Sequence[ArrayLike]  # sequence of N ArrayLikes, each of shape (C, H, W)
TargetBatchType: TypeAlias = Sequence[ArrayLike]  # sequence of N ArrayLikes, each of shape (Cl,)
DatumMetadataBatchType: TypeAlias = Sequence[DatumMetadataType]

Notes:

  • TargetType is used for both ground truth (coming from a dataset) and predictions (output from a model). So for a problem with 4 classes,

    • true label of class 2 (zero-indexed) would be one-hot encoded as [0, 0, 1, 0]

    • prediction from a model would be a vector of pseudo-probabilities, e.g., [0.1, 0.0, 0.7, 0.2]

  • InputType and InputBatchType are shown with shapes following PyTorch channels-first convention
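The two uses of TargetType described in the notes can be reproduced in a few lines of NumPy. This is a minimal sketch with illustrative variable names. It builds the one-hot ground truth for class 2 and shows that a probability vector decodes to a class index in the same way.

```python
import numpy as np

num_classes = 4

# Ground truth: class index 2 (zero-indexed) as a one-hot vector
true_target = np.eye(num_classes)[2]

# Prediction: pseudo-probabilities over the same 4 classes
pred_target = np.array([0.1, 0.0, 0.7, 0.2])

# Both are shape-(Cl,) vectors, so both decode to a class index the same way
true_class = int(np.argmax(true_target))
pred_class = int(np.argmax(pred_target))
```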

These type aliases along with the versions of the various component protocols that use these types can be imported from maite.protocols.image_classification (if necessary):

# import protocol classes
from maite.protocols.image_classification import (
    Dataset,
    DataLoader,
    Model,
    Augmentation,
    Metric
)

# import type aliases
from maite.protocols.image_classification import (
    InputType,
    TargetType,
    DatumMetadataType,
    InputBatchType,
    TargetBatchType,
    DatumMetadataBatchType
)

Alternatively, image classification components and types can be accessed via the module directly:

import maite.protocols.image_classification as ic

# model: ic.Model = load_model(...)

2.2 Object Detection

For object detection with D_i detections in an image i, we have:

from typing import Protocol, Sequence, TypeAlias, TypedDict

from maite.protocols import ArrayLike

# define type to store an id of each datum (additional fields can be added by defining structurally-assignable TypedDict)
class DatumMetadataType(TypedDict):
    id: str | int

class ObjectDetectionTarget(Protocol):
    @property
    def boxes(self) -> ArrayLike: ...  # shape-(D_i, 4) tensor of bounding boxes with format X0, Y0, X1, Y1

    @property
    def labels(self) -> ArrayLike: ...  # shape-(D_i,) tensor of labels for each box

    @property
    def scores(self) -> ArrayLike: ...  # shape-(D_i,) tensor of scores for each box (e.g., probabilities)

InputType: TypeAlias = ArrayLike  # shape-(C, H, W) tensor with single image
TargetType: TypeAlias = ObjectDetectionTarget

InputBatchType: TypeAlias = Sequence[ArrayLike]  # sequence of N ArrayLikes, each of shape (C, H, W)
TargetBatchType: TypeAlias = Sequence[TargetType]  # sequence of object detection "target" objects
DatumMetadataBatchType: TypeAlias = Sequence[DatumMetadataType]

Notes:

  • ObjectDetectionTarget contains a single label and score per box

  • InputType and InputBatchType are shown with shapes following PyTorch channels-first convention
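To make the shape conventions concrete, the sketch below defines a local stand-in protocol (BoxesLabelsScores, a hypothetical name that mirrors the three properties listed above, not MAITE's own class) and a dataclass that satisfies it. The helper checks that each box has exactly one label and one score.

```python
from dataclasses import dataclass
from typing import Protocol

import numpy as np


class BoxesLabelsScores(Protocol):
    """Local stand-in mirroring the three properties listed above."""

    @property
    def boxes(self) -> np.ndarray: ...
    @property
    def labels(self) -> np.ndarray: ...
    @property
    def scores(self) -> np.ndarray: ...


@dataclass
class Detections:
    boxes: np.ndarray   # shape (D_i, 4), format x0, y0, x1, y1
    labels: np.ndarray  # shape (D_i,)
    scores: np.ndarray  # shape (D_i,)


def num_detections(target: BoxesLabelsScores) -> int:
    # Each box has exactly one label and one score, so the leading dims agree
    n = target.boxes.shape[0]
    assert target.labels.shape == (n,) and target.scores.shape == (n,)
    return n


target = Detections(
    boxes=np.array([[1, 3, 5, 9], [2, 5, 8, 12]]),
    labels=np.array([0, 3]),
    scores=np.array([0.9, 0.4]),
)
count = num_detections(target)
```

A plain dataclass is enough here because protocol properties are satisfied by ordinary attributes of a compatible type.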

3 Models

All models implement a __call__ method that takes the InputBatchType and produces the TargetBatchType appropriate for the given machine learning task.

import maite.protocols.image_classification as ic
print(ic.Model.__doc__)
A model protocol for the image classification ML subproblem.

Implementers must provide a __call__ method that operates on a batch of model
inputs (as Sequence[ArrayLike]) and returns a batch of model targets (as
Sequence[ArrayLike]).

Methods
-------

__call__(input_batch: Sequence[ArrayLike]) -> Sequence[ArrayLike]
    Make a model prediction for inputs in input batch. Input batch is expected to
    be Sequence[ArrayLike] with each element of shape (C, H, W).

Attributes
----------

metadata : ModelMetadata
    A typed dictionary containing at least an 'id' field of type str

Examples
--------

We create a multinomial logistic regression classifier for a CIFAR-10-like dataset
with 10 classes and shape-(3, 32, 32) images.

>>> import maite.protocols.image_classification as ic
>>> import numpy as np
>>> import numpy.typing as npt
>>> from maite.protocols import ArrayLike, ModelMetadata
>>> from typing import Sequence

Creating a MAITE-compliant model involves writing a __call__ method that takes a
batch of inputs and returns a batch of predictions (probabilities).

>>> class LinearClassifier:
...     def __init__(self) -> None:
...         # Set up required metadata attribute using the default ModelMetadata type,
...         # using class name for the ID
...         self.metadata: ModelMetadata = {"id": self.__class__.__name__}
...
...         # Initialize weights
...         rng = np.random.default_rng(12345678)
...         num_classes = 10
...         flattened_size = 3 * 32 * 32
...         self.weights = -0.2 + 0.4 * rng.random((flattened_size, num_classes))
...         self.bias = -0.2 + 0.4 * rng.random((1, num_classes))
...
...     def __call__(self, batch: Sequence[ArrayLike]) -> Sequence[npt.NDArray]:
...         # Convert each element in batch to ndarray, flatten,
...         # then combine into 2D array of shape-(N, C * H * W)
...         batch_np = np.vstack([np.asarray(x).flatten() for x in batch])
...
...         # Send input batch through model
...         out = batch_np @ self.weights + self.bias
...         out = np.exp(out) / np.sum(np.exp(out), axis=1, keepdims=True) # softmax
...
...         # Restructure to sequence of shape-(10,) probabilities
...         return [row for row in out]

We set up a test batch, instantiate the model, and apply it to the batch.

>>> batch_size = 8
>>> rng = np.random.default_rng(12345678)
>>> batch: Sequence[ArrayLike] = [-0.2 + 0.4 * rng.random((3, 32, 32)) for _ in range(batch_size)]
>>>
>>> model: ic.Model = LinearClassifier()
>>> out = model(batch)

We can now show the class probabilities returned by the model for each image in the batch.

>>> for probs in out:
...     print(np.round(probs, 2))
[0.16 0.1  0.16 0.14 0.04 0.02 0.06 0.04 0.17 0.1 ]
[0.21 0.16 0.04 0.07 0.08 0.05 0.09 0.03 0.18 0.09]
[0.15 0.11 0.13 0.11 0.09 0.09 0.07 0.04 0.19 0.02]
[0.04 0.08 0.14 0.07 0.12 0.2  0.11 0.06 0.14 0.04]
[0.03 0.08 0.06 0.05 0.17 0.18 0.09 0.03 0.12 0.19]
[0.09 0.04 0.1  0.03 0.32 0.05 0.07 0.04 0.15 0.09]
[0.15 0.05 0.1  0.05 0.11 0.14 0.04 0.08 0.08 0.2 ]
[0.11 0.11 0.08 0.11 0.08 0.05 0.24 0.03 0.08 0.12]

Note that when writing a Model implementer, return types may be narrower than the
return types promised by the protocol (npt.NDArray is a subtype of ArrayLike), but
the argument types must be at least as general as the argument types promised by the
protocol.

import maite.protocols.object_detection as od
print(od.Model.__doc__)
A model protocol for the object detection ML subproblem.

Implementers must provide a __call__ method that operates on a batch of model inputs
(as Sequence[ArrayLike]) and returns a batch of model targets (as
Sequence[ObjectDetectionTarget]).

Methods
-------

__call__(input_batch: Sequence[ArrayLike]) -> Sequence[ObjectDetectionTarget]
    Make a model prediction for inputs in input batch. Elements of input batch
    are expected in the shape (C, H, W).

Attributes
----------

metadata : ModelMetadata
    A typed dictionary containing at least an 'id' field of type str

Examples
--------
We create a simple MAITE-compliant object detection model. Note that it is a dummy model.

>>> from dataclasses import dataclass
>>> from typing import Sequence
>>> import numpy as np
>>> import maite.protocols.object_detection as od
>>> from maite.protocols import ModelMetadata

We define an object detection target dataclass as the output type of the object detection model.

>>> @dataclass
... class MyObjectDetectionTarget:
...     boxes: np.ndarray
...     labels: np.ndarray
...     scores: np.ndarray
...

Specify parameters that will be used to create a dummy dataset.

>>> N_DATAPOINTS = 2  # datapoints in dataset
>>> N_CLASSES = 5  # possible classes that can be detected
>>> C = 3  # number of color channels
>>> H = 32  # img height
>>> W = 32  # img width

Now create a batch of data to form the inputs of the MAITE object detection model.

>>> simple_batch: list[np.ndarray] = [
...     np.random.rand(C, H, W) for _ in range(N_DATAPOINTS)
... ]

We define a simple object detection model. Note that it does not wrap an actual object
detection model; its __call__ method just returns MyObjectDetectionTarget instances.

>>> class ObjectDetectionDummyModel:
...     metadata: ModelMetadata = {"id": "ObjectDetectionDummyModel"}
...
...     def __call__(
...         self, batch: od.InputBatchType
...     ) -> Sequence[MyObjectDetectionTarget]:
...         # For simplicity, we don't run an actual object detection model here;
...         # we just construct the kind of output a model would produce.
...         DETECTIONS_PER_IMG = (
...             2  # number of bounding box detections per image/datapoint
...         )
...         all_boxes = np.array(
...             [[1, 3, 5, 9], [2, 5, 8, 12], [4, 10, 8, 20], [3, 5, 6, 15]]
...         )  # all detection boxes for N_DATAPOINTS
...         all_predictions = list()
...         for datum_idx in range(N_DATAPOINTS):
...             boxes = all_boxes[datum_idx : datum_idx + DETECTIONS_PER_IMG]
...             labels = np.random.randint(N_CLASSES, size=DETECTIONS_PER_IMG)
...             scores = np.random.rand(DETECTIONS_PER_IMG)
...             predictions = MyObjectDetectionTarget(boxes, labels, scores)
...             all_predictions.append(predictions)
...         return all_predictions
...

We can instantiate this class and typehint it as a MAITE object detection model.
By using typehinting, we permit a static typechecker to verify protocol compliance.

>>> od_dummy_model: od.Model = ObjectDetectionDummyModel()
>>> od_dummy_model.metadata
{'id': 'ObjectDetectionDummyModel'}
>>> predictions = od_dummy_model(simple_batch)
>>> predictions  # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
[MyObjectDetectionTarget(boxes=array([[ 1,  3,  5,  9], [ 2,  5,  8, 12]]), labels=array([..., ...]), scores=array([..., ...])),
MyObjectDetectionTarget(boxes=array([[ 2,  5,  8, 12], [ 4, 10,  8, 20]]), labels=array([..., ...]), scores=array([..., ...]))]

4 Datasets and DataLoaders

Datasets provide access to single data items and DataLoaders provide access to batches of data with the input, target, and metadata types corresponding to the given machine learning task.

print(ic.Dataset.__doc__)
A dataset protocol for image classification ML subproblem providing datum-level
data access.

Implementers must provide index lookup (via __getitem__(ind: int) method) and
support len (via __len__() method). Data elements looked up this way correspond
to individual examples (as opposed to batches).

Indexing into or iterating over an image_classification dataset returns a
tuple of types ArrayLike, ArrayLike, and DatumMetadataType.
These correspond to the model input type, model target type, and datum-level
metadata, respectively.

Methods
-------

__getitem__(ind: int) -> tuple[ArrayLike, ArrayLike, DatumMetadataType]
    Provide map-style access to dataset elements. Returned tuple elements
    correspond to model input type, model target type, and datum-specific metadata type,
    respectively.

__len__() -> int
    Return the number of data elements in the dataset.

Attributes
----------

metadata : DatasetMetadata
    A typed dictionary containing at least an 'id' field of type str

Examples
--------

We create a dummy set of data and use it to create a class that implements
this lightweight dataset protocol:

>>> import numpy as np
>>> from typing import Any, TypedDict
>>> from maite.protocols import ArrayLike, DatasetMetadata, DatumMetadata

Assume we have 5 classes, 10 datapoints, and 10 target labels, and that we want
to simply have an integer 'id' field in each datapoint's metadata:

>>> N_CLASSES: int = 5
>>> N_DATUM: int = 10
>>> images: list[np.ndarray] = [np.random.rand(3, 32, 16) for _ in range(N_DATUM)]
>>> targets: np.ndarray = np.eye(N_CLASSES)[np.random.choice(N_CLASSES, N_DATUM)]

We can type our datum metadata as a maite.protocols DatumMetadata, or define our
own TypedDict with additional fields:

>>> class MyDatumMetadata(DatumMetadata):
...     hour_of_day: float
...
>>> datum_metadata = [ MyDatumMetadata(id = i, hour_of_day=np.random.rand()*24) for i in range(N_DATUM) ]

Constructing a compliant dataset just involves a simple wrapper that fetches
individual datapoints, where a datapoint is a single image, target, metadata 3-tuple.

>>> class ImageDataset:
...     def __init__(self,
...                  dataset_name: str,
...                  index2label: dict[int,str],
...                  images: list[np.ndarray],
...                  targets: np.ndarray,
...                  datum_metadata: list[MyDatumMetadata]):
...         self.images = images
...         self.targets = targets
...         self.metadata = DatasetMetadata({'id': dataset_name, 'index2label': index2label})
...         self._datum_metadata = datum_metadata
...     def __len__(self) -> int:
...         return len(self.images)
...     def __getitem__(self, ind: int) -> tuple[np.ndarray, np.ndarray, MyDatumMetadata]:
...         return self.images[ind], self.targets[ind], self._datum_metadata[ind]

We can instantiate this class and typehint it as an image_classification.Dataset.
By using typehinting, we permit a static typechecker to verify protocol compliance.

>>> from maite.protocols import image_classification as ic
>>> dataset: ic.Dataset = ImageDataset('a_dataset',
...                                    {i: f"class_name_{i}" for i in range(N_CLASSES)},
...                                    images, targets, datum_metadata)

Note that when writing a Dataset implementer, return types may be narrower than the
return types promised by the protocol (np.ndarray is a subtype of ArrayLike), but
the argument types must be at least as general as the argument types promised by the
protocol.

print(ic.DataLoader.__doc__)
A dataloader protocol for the image classification ML subproblem providing
batch-level data access.

Implementers must provide an iterable object (returning an iterator via the
__iter__ method) that yields tuples containing batches of data. These tuples
contain types Sequence[ArrayLike] (elements of shape (C, H, W)),
Sequence[ArrayLike] (elements shape (Cl, )), and Sequence[DatumMetadataType],
which correspond to model input batch, model target batch, and a datum metadata batch.

Note: Unlike Dataset, this protocol does not require indexing support, only iterating.

Methods
-------

__iter__() -> Iterator[tuple[Sequence[ArrayLike], Sequence[ArrayLike], Sequence[DatumMetadataType]]]
    Return an iterator over batches of data, where each batch contains a tuple
    of model input batch (as Sequence[ArrayLike]), model target batch (as
    Sequence[ArrayLike]), and batched datum-level metadata
    (as Sequence[DatumMetadataType]), respectively.

print(od.DataLoader.__doc__)
A dataloader protocol for the object detection ML subproblem providing
batch-level data access.

Implementers must provide an iterable object (returning an iterator via the
__iter__ method) that yields tuples containing batches of data. These tuples
contain types Sequence[ArrayLike] (elements of shape (C, H, W)),
Sequence[ObjectDetectionTarget], and Sequence[DatumMetadataType],
which correspond to model input batch, model target batch, and a datum metadata batch.

Note: Unlike Dataset, this protocol does not require indexing support, only iterating.


Methods
-------

__iter__() -> Iterator[tuple[Sequence[ArrayLike], Sequence[ObjectDetectionTarget], Sequence[DatumMetadataType]]]
    Return an iterator over batches of data, where each batch contains a tuple
    of model input batch (as Sequence[ArrayLike]), model target batch (as
    Sequence[ObjectDetectionTarget]), and batched datum-level metadata
    (as Sequence[DatumMetadataType]), respectively.
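Since the DataLoader protocols only require iteration over batch tuples, a loader can be a small wrapper around a map-style dataset. The sketch below is one possible shape under stated assumptions: SimpleDataLoader is a hypothetical class, not a MAITE utility, and plain lists stand in for image and target arrays.

```python
from typing import Any, Iterator, Sequence

Datum = tuple[Any, Any, dict[str, Any]]
Batch = tuple[Sequence[Any], Sequence[Any], Sequence[dict[str, Any]]]


class SimpleDataLoader:
    """Hypothetical loader: iterates over a map-style dataset in fixed-size batches."""

    def __init__(self, dataset: Sequence[Datum], batch_size: int) -> None:
        self.dataset = dataset
        self.batch_size = batch_size

    def __iter__(self) -> Iterator[Batch]:
        for start in range(0, len(self.dataset), self.batch_size):
            stop = min(start + self.batch_size, len(self.dataset))
            # Only __len__ and __getitem__(int) are used on the dataset
            items = [self.dataset[i] for i in range(start, stop)]
            inputs, targets, metadata = zip(*items)
            yield list(inputs), list(targets), list(metadata)


# Five dummy data items; lists stand in for image and one-hot target arrays
dataset: list[Datum] = [([float(i)], [1, 0], {"id": i}) for i in range(5)]
batches = list(SimpleDataLoader(dataset, batch_size=2))
batch_sizes = [len(inputs) for inputs, _, _ in batches]
```

The final batch is smaller than batch_size when the dataset length is not a multiple of it, so downstream components should not assume a fixed batch length.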

5 Augmentations

Augmentations take in and return a batch of data with the InputBatchType, TargetBatchType, and DatumMetadataBatchType types corresponding to the given machine learning task.

Augmentations can access the datum-level metadata associated with each data item to potentially tailor the augmentation to individual items. Augmentations can also associate new datum-level metadata with each data item, e.g., documenting aspects of the actual change that was applied (e.g., the actual rotation angle sampled from a range of possible angles).
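The idea of recording what an augmentation actually did can be sketched as follows. This is a plain function rather than a full Augmentation implementer (it has no metadata attribute), and the rotate_batch name and the rotation_deg key are hypothetical. It rotates each image by a random multiple of 90 degrees and stores the sampled angle in new datum-level metadata.

```python
import numpy as np


def rotate_batch(batch, rng):
    """Rotate each (C, H, W) image by a random multiple of 90 degrees."""
    inputs, targets, metadata = batch
    inputs_aug, metadata_aug = [], []
    for image, md in zip(inputs, metadata):
        k = int(rng.integers(0, 4))  # number of quarter turns actually applied
        # Rotate in the (H, W) plane; np.rot90 returns a view, so copy it
        inputs_aug.append(np.rot90(np.asarray(image), k=k, axes=(1, 2)).copy())
        # Record what was applied without mutating the original metadata
        metadata_aug.append({**md, "rotation_deg": 90 * k})
    return inputs_aug, list(targets), metadata_aug


rng = np.random.default_rng(0)
images = [np.zeros((3, 4, 6)) for _ in range(2)]
targets = [np.array([1.0, 0.0]), np.array([0.0, 1.0])]
metadata = [{"id": 0}, {"id": 1}]

images_aug, targets_aug, metadata_aug = rotate_batch((images, targets, metadata), rng)
```

Building a new metadata dictionary per item keeps the original batch untouched, which matters when the same batch is fed to several augmentations.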

print(ic.Augmentation.__doc__)
An augmentation protocol for the image classification subproblem.

An augmentation is expected to take a batch of data and return a modified version of
that batch. Implementers must provide a single method that takes and returns a
labeled data batch, where a labeled data batch is represented by a tuple of types
Sequence[ArrayLike] (with elements of shape (C, H, W)), Sequence[ArrayLike]
(with elements of shape (Cl, )), and Sequence[DatumMetadataType]. These correspond
to the model input batch type, model target batch type, and datum-level metadata
batch type, respectively.

Methods
-------

__call__(datum: tuple[Sequence[ArrayLike], Sequence[ArrayLike], Sequence[DatumMetadataType]]) -> tuple[Sequence[ArrayLike], Sequence[ArrayLike], Sequence[DatumMetadataType]]
    Return a modified version of original data batch. A data batch is represented
    by a tuple of model input batch (as Sequence[ArrayLike] with elements of shape
    (C, H, W)), model target batch (as Sequence[ArrayLike] with elements of shape
    (Cl,)), and batch metadata (as Sequence[DatumMetadataType]), respectively.

Attributes
----------

metadata : AugmentationMetadata
    A typed dictionary containing at least an 'id' field of type str

Examples
--------

We can write an implementer of the augmentation class as either a function or a class.
The only requirement is that the object provide a __call__ method that takes objects
at least as general as the types promised in the protocol signature and returns types
at least as specific.

>>> import copy
>>> import numpy as np
>>> from typing import Any
>>> from collections.abc import Sequence
>>> from maite.protocols import ArrayLike, DatumMetadata, AugmentationMetadata
>>>
>>> class EnrichedDatumMetadata(DatumMetadata):
...     new_key: int  # add a field to those already in DatumMetadata
...
>>> class ImageAugmentation:
...     def __init__(self, aug_name: str):
...         self.metadata: AugmentationMetadata = {'id': aug_name}
...     def __call__(
...         self,
...         data_batch: tuple[Sequence[ArrayLike], Sequence[ArrayLike], Sequence[DatumMetadata]]
...     ) -> tuple[Sequence[np.ndarray], Sequence[np.ndarray], Sequence[EnrichedDatumMetadata]]:
...         inputs, targets, mds = data_batch
...         # We copy the data passed into __call__ to avoid mutating original inputs.
...         # By converting with np.array, the static type-checker will let us treat
...         # generic ArrayLike as a narrower return type
...         inputs_aug = [copy.copy(np.array(input)) for input in inputs]
...         targets_aug = [copy.copy(np.array(target)) for target in targets]
...         # Modify inputs_aug, targets_aug, or mds_aug as needed
...         # In this example, we just add a new metadata field
...         mds_aug = []
...         for i, md in enumerate(mds):
...             mds_aug.append(EnrichedDatumMetadata(**md, new_key=i))
...         return inputs_aug, targets_aug, mds_aug

We can typehint an instance of the above class as an Augmentation in the
image_classification domain:

>>> from maite.protocols import image_classification as ic
>>> im_aug: ic.Augmentation = ImageAugmentation(aug_name = 'an_augmentation')

print(od.Augmentation.__doc__)
An augmentation protocol for the object detection subproblem.

An augmentation is expected to take a batch of data and return a modified version of
that batch. Implementers must provide a single method that takes and returns a
labeled data batch, where a labeled data batch is represented by a tuple of types
Sequence[ArrayLike], Sequence[ObjectDetectionTarget], and Sequence[DatumMetadataType].
These correspond to the model input batch type, model target batch type, and datum-level
metadata batch type, respectively.

Methods
-------

__call__(datum: tuple[Sequence[ArrayLike], Sequence[ObjectDetectionTarget], Sequence[DatumMetadataType]]) -> tuple[Sequence[ArrayLike], Sequence[ObjectDetectionTarget], Sequence[DatumMetadataType]]
    Return a modified version of original data batch. A data batch is represented
    by a tuple of model input batch (as Sequence[ArrayLike] with elements of shape
    (C, H, W)), model target batch (as Sequence[ObjectDetectionTarget]), and
    batch metadata (as Sequence[DatumMetadataType]), respectively.

Attributes
----------

metadata : AugmentationMetadata
    A typed dictionary containing at least an 'id' field of type str

Examples
--------

We can write an implementer of the augmentation class as either a function or a class.
The only requirement is that the object provide a __call__ method that takes objects
at least as general as the types promised in the protocol signature and returns types
at least as specific.

We create a dummy set of data and use it to create a class that implements this
lightweight protocol and augments the data:

>>> import numpy as np
>>> np.random.seed(1)
>>> import copy
>>> from dataclasses import dataclass
>>> from typing import Any
>>> from maite.protocols import AugmentationMetadata, object_detection as od

First, we specify parameters that will be used to create the dummy dataset.

>>> N_DATAPOINTS = 3  # datapoints in dataset
>>> N_CLASSES = 2  # possible classes that can be detected
>>> C = 3  # number of color channels
>>> H = 10  # img height
>>> W = 10  # img width

Next, we create the input data to be used by the Augmentation. In this example, we
create the following batch of object detection data:

• xb is the input batch data. Our batch will include N_DATAPOINTS number of samples. Note
that we initialize all of the data to zeros in this example to demonstrate the augmentations
better.

• yb is the object detection target data, which in this example represents zero object
detections for each input datum (by having empty bounding boxes and class labels and scores).

• mdb is the associated metadata for each input datum.

>>> @dataclass
... class MyObjectDetectionTarget:
...     boxes: np.ndarray
...     labels: np.ndarray
...     scores: np.ndarray
>>> xb: od.InputBatchType = list(np.zeros((N_DATAPOINTS, C, H, W)))
>>> yb: od.TargetBatchType = list(
...     MyObjectDetectionTarget(boxes=np.empty(0), labels=np.empty(0), scores=np.empty(0))
...     for _ in range(N_DATAPOINTS)
... )
>>> mdb: od.DatumMetadataBatchType = list({"id": i} for i in range(N_DATAPOINTS))
>>> # Display the first datum in batch, first color channel, and only first 5 rows and cols
>>> np.array(xb[0])[0][:5, :5]
array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]])

Now we create the Augmentation, which will apply random noise (rounded to 3 decimal
places) to the input data using the numpy.random.random and np.round functions.

>>> np_noise = lambda shape: np.round(np.random.random(shape), 3)
>>> class ImageAugmentation:
...     def __init__(self, aug_func: Any, metadata: AugmentationMetadata):
...         self.aug_func = aug_func
...         self.metadata = metadata
...     def __call__(
...         self,
...         batch: tuple[od.InputBatchType, od.TargetBatchType, od.DatumMetadataBatchType],
...     ) -> tuple[od.InputBatchType, od.TargetBatchType, od.DatumMetadataBatchType]:
...         xb, yb, mdb = batch
...         # Copy data passed into __call__ to avoid mutating original inputs
...         xb_aug = [copy.copy(input) for input in xb]
...         # Add random noise to the copied input batch data, xb_aug
...         # (Note that all batch data dimensions (shapes) are the same in this example)
...         shape = np.array(xb[0]).shape
...         xb_aug = [x + self.aug_func(shape) for x in xb_aug]
...         # Note that this example augmentation only affects inputs--not targets
...         return xb_aug, yb, mdb

We can typehint an instance of the above class as an Augmentation in the object
detection domain:

>>> noise: od.Augmentation = ImageAugmentation(np_noise, metadata={"id": "np_rand_noise"})

Now we can apply the noise augmentation to our 3-tuple batch of data. Recall that
our data was initialized to all zeros, so any non-zero values in the augmented data
are a result of the augmentation.

>>> xb_aug, yb_aug, mdb_aug = noise((xb, yb, mdb))
>>> # Display the first datum in batch, first color channel, and only first 5 rows and cols
>>> np.array(xb_aug[0])[0][:5, :5]
array([[0.417, 0.72 , 0.   , 0.302, 0.147],
       [0.419, 0.685, 0.204, 0.878, 0.027],
       [0.801, 0.968, 0.313, 0.692, 0.876],
       [0.098, 0.421, 0.958, 0.533, 0.692],
       [0.989, 0.748, 0.28 , 0.789, 0.103]])

6 Metrics

The Metric protocol is inspired by the design of existing libraries like TorchMetrics and TorchEval. The update method operates on batches of predictions and truth labels, either caching them for later computation of the metric (via compute) or updating sufficient statistics in an online fashion.
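The two update strategies can be contrasted with a deliberately simplified example. The classes below are hypothetical and take a single sequence of scores rather than predictions and targets, so they are not Metric implementers. They only illustrate the difference between caching inputs and maintaining sufficient statistics.

```python
from typing import Any, Sequence


class CachingMeanScore:
    """Caches every score in update(); does all arithmetic in compute()."""

    def __init__(self) -> None:
        self._scores: list[float] = []

    def update(self, scores: Sequence[float]) -> None:
        self._scores.extend(scores)

    def compute(self) -> dict[str, Any]:
        return {"mean": sum(self._scores) / len(self._scores)}

    def reset(self) -> None:
        self._scores = []


class OnlineMeanScore:
    """Keeps only a running sum and count, so memory use stays constant."""

    def __init__(self) -> None:
        self._sum = 0.0
        self._count = 0

    def update(self, scores: Sequence[float]) -> None:
        self._sum += sum(scores)
        self._count += len(scores)

    def compute(self) -> dict[str, Any]:
        return {"mean": self._sum / self._count}

    def reset(self) -> None:
        self._sum, self._count = 0.0, 0


cached, online = CachingMeanScore(), OnlineMeanScore()
for batch in ([1.0, 2.0, 3.0], [4.0, 5.0]):
    cached.update(batch)
    online.update(batch)
```

Both variants return the same value from compute. The online form suits metrics that decompose into running totals, while caching is needed when the metric depends on all predictions at once.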

print(ic.Metric.__doc__)
A metric protocol for the image classification ML subproblem.

A metric in this sense is expected to measure the level of agreement between model
predictions and ground-truth labels.

Methods
-------

update(preds: Sequence[ArrayLike], targets: Sequence[ArrayLike]) -> None
    Add predictions and targets to metric's cache for later calculation. Both
    preds and targets are expected to be sequences with elements of shape (Cl,).

compute() -> dict[str, Any]
    Compute metric value(s) for currently cached predictions and targets, returned as
    a dictionary.

reset() -> None
    Clear contents of current metric's cache of predictions and targets.

Attributes
----------

metadata : MetricMetadata
    A typed dictionary containing at least an 'id' field of type str

Examples
--------

Create a basic accuracy metric and test it on a small example dataset:

>>> from typing import Any, Sequence
>>> import numpy as np
>>> from maite.protocols import ArrayLike, MetricMetadata
>>> from maite.protocols import image_classification as ic

>>> class MyAccuracy:
...    metadata: MetricMetadata = {'id': 'Example Multiclass Accuracy'}
...
...    def __init__(self):
...        self._total = 0
...        self._correct = 0
...
...    def reset(self) -> None:
...        self._total = 0
...        self._correct = 0
...
...    def update(self, preds: Sequence[ArrayLike], targets: Sequence[ArrayLike]) -> None:
...        model_probs = [np.array(r) for r in preds]
...        true_onehot = [np.array(r) for r in targets]
...
...        # Stack into single array, convert to class indices
...        model_classes = np.vstack(model_probs).argmax(axis=1)
...        truth_classes = np.vstack(true_onehot).argmax(axis=1)
...
...        # Compare classes and update running counts
...        same = (model_classes == truth_classes)
...        self._total += len(same)
...        self._correct += same.sum()
...
...    def compute(self) -> dict[str, Any]:
...        if self._total > 0:
...            return {"accuracy": self._correct / self._total}
...        else:
...            raise Exception("No batches processed yet.")

Instantiate this class and typehint it as an image_classification.Metric.
By using typehinting, we permit a static typechecker to verify protocol compliance.

>>> accuracy: ic.Metric = MyAccuracy()

To use the metric, call update() for each batch of predictions and truth values, then call compute() to calculate the final metric values.

>>> # batch 1
>>> model_probs = [np.array([0.8, 0.1, 0.0, 0.1]), np.array([0.1, 0.2, 0.6, 0.1])] # predicted classes: 0, 2
>>> true_onehot = [np.array([1.0, 0.0, 0.0, 0.0]), np.array([0.0, 1.0, 0.0, 0.0])] # true classes: 0, 1
>>> accuracy.update(model_probs, true_onehot)
>>> print(accuracy.compute())
{'accuracy': 0.5}
>>>
>>> # batch 2
>>> model_probs = [np.array([0.1, 0.1, 0.7, 0.1]), np.array([0.0, 0.1, 0.0, 0.9])] # predicted classes: 2, 3
>>> true_onehot = [np.array([0.0, 0.0, 1.0, 0.0]), np.array([0.0, 0.0, 0.0, 1.0])] # true classes: 2, 3
>>> accuracy.update(model_probs, true_onehot)
>>>
>>> print(accuracy.compute())
{'accuracy': 0.75}

print(od.Metric.__doc__)
A metric protocol for the object detection ML subproblem.

A metric in this sense is expected to measure the level of agreement between model
predictions and ground-truth labels.

Methods
-------

update(preds: Sequence[ObjectDetectionTarget], targets: Sequence[ObjectDetectionTarget]) -> None
    Add predictions and targets to metric's cache for later calculation.

compute() -> dict[str, Any]
    Compute metric value(s) for currently cached predictions and targets, returned as
    a dictionary.

reset() -> None
    Clear contents of current metric's cache of predictions and targets.

Attributes
----------

metadata : MetricMetadata
    A typed dictionary containing at least an 'id' field of type str

Examples
--------

Below, we write and test a class that implements the Metric protocol for object detection.
For simplicity, the metric will compute the intersection over union (IoU) averaged over
all predicted and associated ground truth boxes for a single image, and then take the mean
over the per-image means.

Note that when writing a Metric implementer, return types may be narrower than the
return types promised by the protocol, but the argument types must be at least as
general as the argument types promised by the protocol.
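
The variance rule in the note above can be sketched with plain typing.Protocol classes. The names `MetricLike` and `FlexibleMetric` are hypothetical and exist only for this illustration; they are not MAITE types.

```python
from typing import Any, Iterable, Protocol, Sequence


class MetricLike(Protocol):
    """Hypothetical protocol, loosely modeled on the shape of a metric."""

    def update(self, preds: Sequence[float], targets: Sequence[float]) -> None: ...

    def compute(self) -> dict[str, Any]: ...


class FlexibleMetric:
    def __init__(self) -> None:
        self.errors: list[float] = []

    # Argument types are MORE general than the protocol's (every Sequence is an Iterable)
    def update(self, preds: Iterable[float], targets: Iterable[float]) -> None:
        self.errors.extend(abs(p - t) for p, t in zip(preds, targets))

    # Return type is NARROWER than the protocol's (dict[str, float] vs. dict[str, Any])
    def compute(self) -> dict[str, float]:
        return {"mean_abs_error": sum(self.errors) / len(self.errors)}


# A static type checker accepts this assignment because both rules are satisfied
metric: MetricLike = FlexibleMetric()
metric.update([1.0, 2.0], [1.0, 4.0])
result = metric.compute()
```

Reversing either direction (for example, accepting only list[float] in update) would cause a static type checker to reject the assignment, since callers relying on the protocol could pass arguments the implementer does not accept.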

>>> from dataclasses import dataclass
>>> from maite.protocols import ArrayLike, MetricMetadata
>>> from typing import Any, Sequence
>>> import maite.protocols.object_detection as od
>>> import numpy as np

>>> class MyIoUMetric:
...
...     def __init__(self, id: str):
...         self.pred_boxes = []  # elements correspond to predicted boxes in single image
...         self.target_boxes = []  # elements correspond to ground truth boxes in single image
...         # Store provided id for this metric instance
...         self.metadata = MetricMetadata(
...             id=id
...         )
...
...     def reset(self) -> None:
...         self.pred_boxes = []
...         self.target_boxes = []
...
...     def update(
...         self,
...         pred_batch: Sequence[od.ObjectDetectionTarget],
...         target_batch: Sequence[od.ObjectDetectionTarget],
...     ) -> None:
...         self.pred_boxes.extend(pred_batch)
...         self.target_boxes.extend(target_batch)
...
...     @staticmethod
...     def iou_vec(boxes_a: ArrayLike, boxes_b: ArrayLike) -> np.ndarray:
...         # Narrow ArrayLike inputs to NumPy arrays, then break up points into separate columns
...         x0a, y0a, x1a, y1a = np.split(np.asarray(boxes_a), 4, axis=1)
...         x0b, y0b, x1b, y1b = np.split(np.asarray(boxes_b), 4, axis=1)
...         # Calculate intersections
...         xi_0, yi_0 = np.split(
...             np.maximum(np.append(x0a, y0a, axis=1), np.append(x0b, y0b, axis=1)),
...             2,
...             axis=1,
...         )
...         xi_1, yi_1 = np.split(
...             np.minimum(np.append(x1a, y1a, axis=1), np.append(x1b, y1b, axis=1)),
...             2,
...             axis=1,
...         )
...         ints: np.ndarray = np.maximum(0, xi_1 - xi_0) * np.maximum(0, yi_1 - yi_0)
...         # Calculate unions (as sum of areas minus their intersection)
...         unions: np.ndarray = (
...             (x1a - x0a) * (y1a - y0a)
...             + (x1b - x0b) * (y1b - y0b)
...             - ints  # use the clamped intersection so non-overlapping boxes subtract 0
...         )
...         return ints / unions
...
...     def compute(self) -> dict[str, Any]:
...         mean_iou_by_img: list[float] = []
...         for pred_box, tgt_box in zip(self.pred_boxes, self.target_boxes):
...             single_img_ious = self.iou_vec(pred_box.boxes, tgt_box.boxes)
...             mean_iou_by_img.append(float(np.mean(single_img_ious)))
...         return {"mean_iou": float(np.mean(np.array(mean_iou_by_img)))}
...

Now we can instantiate our IoU Metric class:

>>> iou_metric: od.Metric = MyIoUMetric(id="IoUMetric")

To use the metric, we populate two lists that encode predicted object detections
and ground truth object detections for a single image. (Ordinarily, predictions
would be produced by a model.)

>>> prediction_boxes: list[tuple[int, int, int, int]] = [
...     (1, 1, 12, 12),
...     (100, 100, 120, 120),
...     (180, 180, 270, 270),
... ]

>>> target_boxes: list[tuple[int, int, int, int]] = [
...     (1, 1, 10, 10),
...     (100, 100, 120, 120),
...     (200, 200, 300, 300),
... ]

The MAITE Metric protocol requires the pred_batch and target_batch arguments to the
update method to be assignable to Sequence[ObjectDetectionTarget] (where ObjectDetectionTarget
encodes detections in a single image). We define an implementation of ObjectDetectionTarget and use it
to pass ground truth and predicted detections.

>>> @dataclass
... class ObjectDetectionTargetImpl:
...     boxes: np.ndarray
...     labels: np.ndarray
...     scores: np.ndarray

>>> num_boxes = len(target_boxes)
>>> fake_labels = np.random.randint(0, 9, num_boxes)
>>> fake_scores = np.zeros(num_boxes)
>>> pred_batch = [
...     ObjectDetectionTargetImpl(
...         boxes=np.array(prediction_boxes), labels=fake_labels, scores=fake_scores
...     )
... ]
>>> target_batch: Sequence[ObjectDetectionTargetImpl] = [
...     ObjectDetectionTargetImpl(
...         boxes=np.array(target_boxes), labels=fake_labels, scores=fake_scores
...     )
... ]

Finally, we call update using this one-element batch, compute the metric value, and print it:

>>> iou_metric.update(pred_batch, target_batch)
>>> print(iou_metric.compute())
{'mean_iou': 0.6802112029384757}
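
As a sanity check on the printed value, the per-box IoUs can be worked out directly from the box coordinates used above, with boxes read as (x0, y0, x1, y1) as in the metric's code. The helper `scalar_iou` is a hypothetical name used only for this check.

```python
def scalar_iou(a, b):
    """IoU of two boxes given as (x0, y0, x1, y1)."""
    iw = max(0, min(a[2], b[2]) - max(a[0], b[0]))  # intersection width
    ih = max(0, min(a[3], b[3]) - max(a[1], b[1]))  # intersection height
    inter = iw * ih
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)


preds = [(1, 1, 12, 12), (100, 100, 120, 120), (180, 180, 270, 270)]
targets = [(1, 1, 10, 10), (100, 100, 120, 120), (200, 200, 300, 300)]

# Pair 1: 81 / 121, pair 2: 400 / 400, pair 3: 4900 / 13200
ious = [scalar_iou(p, t) for p, t in zip(preds, targets)]
mean_iou = sum(ious) / len(ious)
```

Averaging the three ratios gives approximately 0.6802, in agreement with the value returned by compute().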

7 Workflows

MAITE provides high-level utilities for common workflows such as evaluate and predict. They can be called with either Datasets or DataLoaders, and optionally with an Augmentation.

The evaluate function can optionally return the model predictions and (potentially-augmented) data batches used during inference.

The predict function returns the model predictions and (potentially-augmented) data batches used during inference, essentially calling evaluate with a dummy metric.

from maite.workflows import evaluate, predict
# we can also import from object_detection module
# where the function call signature is the same
print(evaluate.__doc__)
Evaluate a model's performance on data according to some metric with optional augmentation.

Some data source (either a dataloader or a dataset) must be provided
or an InvalidArgument exception is raised.

Parameters
----------
model : SomeModel
    Maite Model object.

metric : SomeMetric | None, (default=None)
    Compatible maite Metric.

dataloader : SomeDataloader | None, (default=None)
    Compatible maite dataloader.

dataset : SomeDataset | None, (default=None)
    Compatible maite dataset.

batch_size : int, (default=1)
    Batch size for use with dataset (ignored if dataset=None).

augmentation : SomeAugmentation | None, (default=None)
    Compatible maite augmentation.

return_augmented_data : bool, (default=False)
    Set to True to return post-augmentation data as a function output.

return_preds : bool, (default=False)
    Set to True to return raw predictions as a function output.

Returns
-------
tuple[dict[str, Any], Sequence[TargetType], Sequence[tuple[InputBatchType, TargetBatchType, DatumMetadataBatchType]]]
    Tuple of returned metric value, sequence of model predictions, and
    sequence of data batch tuples fed to the model during inference. The actual
    types represented by InputBatchType, TargetBatchType, and DatumMetadataBatchType will vary
    by the domain of the components provided as input arguments (e.g., image
    classification or object detection).
    Note that the second and third return values will be empty if
    return_preds is False or return_augmented_data is False, respectively.
print(predict.__doc__)
Make predictions for a given model & data source with optional augmentation.

Some data source (either a dataloader or a dataset) must be provided
or an InvalidArgument exception is raised.

Parameters
----------
model : SomeModel
    Maite Model object.

dataloader : SomeDataloader | None, (default=None)
    Compatible maite dataloader.

dataset : SomeDataset | None, (default=None)
    Compatible maite dataset.

batch_size : int, (default=1)
    Batch size for use with dataset (ignored if dataset=None).

augmentation : SomeAugmentation | None, (default=None)
    Compatible maite augmentation.

Returns
-------
tuple[Sequence[SomeTargetBatchType], Sequence[tuple[SomeInputBatchType, SomeTargetBatchType, SomeMetadataBatchType]]]
    A tuple of the predictions (as a sequence of batches) and a sequence
    of tuples containing the information associated with each batch.