::

    Copyright 2024, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
    Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
    SPDX-License-Identifier: MIT
Overview of MAITE Protocols
===========================
MAITE provides protocols for the following AI components:
- models
- datasets
- dataloaders
- augmentations
- metrics
MAITE protocols specify expected interfaces of these components (i.e., a
minimal set of required attributes, methods, and method type signatures)
to promote interoperability in test and evaluation (T&E). This enables
the creation of higher-level workflows (e.g., an ``evaluate`` utility)
that can interact with any components that conform to the protocols.
1 Concept: Bridging ArrayLikes
------------------------------
MAITE uses a type called ``ArrayLike`` (following NumPy’s
`interoperability approach
<https://numpy.org/doc/stable/user/basics.interoperability.html>`__)
that helps components that natively use different flavors of tensors
(e.g., NumPy ndarray, PyTorch Tensor, JAX ndarray) work together.
In this example, the functions “type narrow” from ``ArrayLike`` to the
type they want to work with internally. Note that, depending on the actual
input type, this doesn’t necessarily require a conversion.
.. code:: ipython3

    import numpy as np
    import torch

    from maite.protocols import ArrayLike

    def my_numpy_fn(x: ArrayLike) -> np.ndarray:
        arr = np.asarray(x)
        # ...
        return arr

    def my_torch_fn(x: ArrayLike) -> torch.Tensor:
        tensor = torch.as_tensor(x)
        # ...
        return tensor

    # can apply NumPy function to PyTorch Tensor
    np_out = my_numpy_fn(torch.rand(2, 3))

    # can apply PyTorch function to NumPy array
    torch_out = my_torch_fn(np.random.rand(2, 3))

    # note: no performance hit from conversion when all `ArrayLike`s are from the
    # same library or can share the same underlying memory
    torch_out = my_torch_fn(torch.rand(2, 3))
By using bridging, MAITE permits implementers of the protocol to
interact internally with their own types while exposing a more open
interface to other MAITE-compliant components.
2 Data Types
------------
MAITE represents an *individual* data item as a tuple of:
- input (e.g., an image),
- target (e.g., a label), and
- metadata (at the datum level)
and a *batch* of data items as a tuple of:
- input batches,
- target batches, and
- metadata batches.
MAITE provides versions of ``Model``, ``Dataset``, ``DataLoader``,
``Augmentation``, and ``Metric`` protocols that correspond to different
machine learning tasks (e.g., image classification, object detection) by
parameterizing protocol interfaces on the particular input, target, and
metadata types associated with that task.
2.1 Image Classification
~~~~~~~~~~~~~~~~~~~~~~~~
For image classification with ``Cl`` image classes, we have:
.. code:: python

    # define type to store an id of each datum (additional fields can be
    # added by defining a structurally-assignable TypedDict)
    class DatumMetadataType(TypedDict):
        id: str | int

    InputType: TypeAlias = ArrayLike  # shape-(C, H, W) tensor with single image
    TargetType: TypeAlias = ArrayLike  # shape-(Cl,) tensor of one-hot encoded true class or predicted probabilities

    InputBatchType: TypeAlias = Sequence[ArrayLike]  # sequence of N tensors, each of shape (C, H, W)
    TargetBatchType: TypeAlias = Sequence[ArrayLike]  # sequence of N tensors, each of shape (Cl,)
    DatumMetadataBatchType: TypeAlias = Sequence[DatumMetadataType]
Notes:
- ``TargetType`` is used for both ground truth (coming from a dataset)
and predictions (output from a model). So for a problem with 4
classes,
- true label of class 2 would be one-hot encoded as ``[0, 0, 1, 0]``
- prediction from a model would be a vector of pseudo-probabilities,
e.g., ``[0.1, 0.0, 0.7, 0.2]``
- ``InputType`` and ``InputBatchType`` are shown with shapes following
PyTorch channels-first convention
These type aliases, along with the versions of the various component
protocols that use them, can be imported from
``maite.protocols.image_classification``:
.. code:: ipython3

    # import protocol classes
    from maite.protocols.image_classification import (
        Dataset,
        DataLoader,
        Model,
        Augmentation,
        Metric,
    )

    # import type aliases
    from maite.protocols.image_classification import (
        InputType,
        TargetType,
        DatumMetadataType,
        InputBatchType,
        TargetBatchType,
        DatumMetadataBatchType,
    )
Alternatively, image classification components and types can be accessed
via the module directly:
.. code:: ipython3

    import maite.protocols.image_classification as ic

    # model: ic.Model = load_model(...)
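To make these aliases concrete, below is a minimal sketch of a single conforming
datum and a one-element batch (the 3×32×32 image shape and 4-class problem are
assumptions chosen for illustration):

.. code:: python

    import numpy as np
    import maite.protocols.image_classification as ic

    # single datum: image input, one-hot target, and datum-level metadata
    image: ic.InputType = np.zeros((3, 32, 32))             # shape-(C, H, W)
    target: ic.TargetType = np.array([0.0, 0.0, 1.0, 0.0])  # true class 2 of 4
    metadata: ic.DatumMetadataType = {"id": 0}

    # a batch is parallel sequences of inputs, targets, and metadata
    input_batch: ic.InputBatchType = [image]
    target_batch: ic.TargetBatchType = [target]
    metadata_batch: ic.DatumMetadataBatchType = [metadata]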
2.2 Object Detection
~~~~~~~~~~~~~~~~~~~~
For object detection with ``D_i`` detections in an image ``i``, we have:
.. code:: python

    # define type to store an id of each datum (additional fields can be
    # added by defining a structurally-assignable TypedDict)
    class DatumMetadataType(TypedDict):
        id: str | int

    class ObjectDetectionTarget(Protocol):
        @property
        def boxes(self) -> ArrayLike: ...  # shape-(D_i, 4) tensor of bounding boxes w/format X0, Y0, X1, Y1

        @property
        def labels(self) -> ArrayLike: ...  # shape-(D_i,) tensor of labels for each box

        @property
        def scores(self) -> ArrayLike: ...  # shape-(D_i,) tensor of scores for each box (e.g., probabilities)

    InputType: TypeAlias = ArrayLike  # shape-(C, H, W) tensor with single image
    TargetType: TypeAlias = ObjectDetectionTarget

    InputBatchType: TypeAlias = Sequence[ArrayLike]  # sequence of N ArrayLikes, each of shape (C, H, W)
    TargetBatchType: TypeAlias = Sequence[TargetType]  # sequence of object detection "target" objects
    DatumMetadataBatchType: TypeAlias = Sequence[DatumMetadataType]
Notes:
- ``ObjectDetectionTarget`` contains a single label and score per box
- ``InputType`` and ``InputBatchType`` are shown with shapes following
PyTorch channels-first convention
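Because ``ObjectDetectionTarget`` is a protocol, any class exposing ``boxes``,
``labels``, and ``scores`` attributes satisfies it structurally; a dataclass (as
used in the examples below) is one minimal sketch:

.. code:: python

    from dataclasses import dataclass

    import numpy as np

    @dataclass
    class MyObjectDetectionTarget:
        boxes: np.ndarray   # shape-(D_i, 4), rows formatted as X0, Y0, X1, Y1
        labels: np.ndarray  # shape-(D_i,)
        scores: np.ndarray  # shape-(D_i,)

    # one image's worth of detections: a single box with label 1 and score 0.9
    target = MyObjectDetectionTarget(
        boxes=np.array([[0.0, 0.0, 10.0, 10.0]]),
        labels=np.array([1]),
        scores=np.array([0.9]),
    )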
3 Models
--------
All models implement a ``__call__`` method that takes the
``InputBatchType`` and produces the ``TargetBatchType`` appropriate for
the given machine learning task.
.. code:: ipython3

    import maite.protocols.image_classification as ic

    print(ic.Model.__doc__)
.. parsed-literal::
A model protocol for the image classification ML subproblem.
Implementers must provide a `__call__` method that operates on a batch of model
inputs (as `Sequence[ArrayLike]`) and returns a batch of model targets (as
`Sequence[ArrayLike]`).
Methods
-------
__call__(input_batch: Sequence[ArrayLike]) -> Sequence[ArrayLike]
Make a model prediction for inputs in input batch. Input batch is expected to
be `Sequence[ArrayLike]` with each element of shape `(C, H, W)`.
Attributes
----------
metadata : ModelMetadata
A typed dictionary containing at least an 'id' field of type str
Examples
--------
We create a multinomial logistic regression classifier for a CIFAR-10-like dataset
with 10 classes and shape-(3, 32, 32) images.
>>> import maite.protocols.image_classification as ic
>>> import numpy as np
>>> import numpy.typing as npt
>>> from maite.protocols import ArrayLike, ModelMetadata
>>> from typing import Sequence
Creating a MAITE-compliant model involves writing a `__call__` method that takes a
batch of inputs and returns a batch of predictions (probabilities).
>>> class LinearClassifier:
...     def __init__(self) -> None:
...         # Set up required metadata attribute using the default `ModelMetadata` type,
...         # using class name for the ID
...         self.metadata: ModelMetadata = {"id": self.__class__.__name__}
...
...         # Initialize weights
...         rng = np.random.default_rng(12345678)
...         num_classes = 10
...         flattened_size = 3 * 32 * 32
...         self.weights = -0.2 + 0.4 * rng.random((flattened_size, num_classes))
...         self.bias = -0.2 + 0.4 * rng.random((1, num_classes))
...
...     def __call__(self, batch: Sequence[ArrayLike]) -> Sequence[npt.NDArray]:
...         # Convert each element in batch to ndarray and flatten,
...         # then stack into a 2D array of shape-(N, C*H*W)
...         batch_np = np.vstack([np.asarray(x).flatten() for x in batch])
...
...         # Send input batch through model
...         out = batch_np @ self.weights + self.bias
...         out = np.exp(out) / np.sum(np.exp(out), axis=1, keepdims=True)  # softmax
...
...         # Restructure to sequence of shape-(10,) probabilities
...         return [row for row in out]
We set up a test batch, instantiate the model, and apply it to the batch.
>>> batch_size = 8
>>> rng = np.random.default_rng(12345678)
>>> batch: Sequence[ArrayLike] = [-0.2 + 0.4 * rng.random((3, 32, 32)) for _ in range(batch_size)]
>>>
>>> model: ic.Model = LinearClassifier()
>>> out = model(batch)
We can now show the class probabilities returned by the model for each image in the batch.
>>> for probs in out:
... print(np.round(probs, 2))
[0.16 0.1 0.16 0.14 0.04 0.02 0.06 0.04 0.17 0.1 ]
[0.21 0.16 0.04 0.07 0.08 0.05 0.09 0.03 0.18 0.09]
[0.15 0.11 0.13 0.11 0.09 0.09 0.07 0.04 0.19 0.02]
[0.04 0.08 0.14 0.07 0.12 0.2 0.11 0.06 0.14 0.04]
[0.03 0.08 0.06 0.05 0.17 0.18 0.09 0.03 0.12 0.19]
[0.09 0.04 0.1 0.03 0.32 0.05 0.07 0.04 0.15 0.09]
[0.15 0.05 0.1 0.05 0.11 0.14 0.04 0.08 0.08 0.2 ]
[0.11 0.11 0.08 0.11 0.08 0.05 0.24 0.03 0.08 0.12]
Note that when writing a Model implementer, return types may be narrower than the
return types promised by the protocol (npt.NDArray is a subtype of ArrayLike), but
the argument types must be at least as general as the argument types promised by the
protocol.
.. code:: ipython3

    import maite.protocols.object_detection as od

    print(od.Model.__doc__)
.. parsed-literal::
A model protocol for the object detection ML subproblem.
Implementers must provide a `__call__` method that operates on a batch of model inputs
(as `Sequence[ArrayLike]`) and returns a batch of model targets (as
`Sequence[ObjectDetectionTarget]`).
Methods
-------
__call__(input_batch: Sequence[ArrayLike]) -> Sequence[ObjectDetectionTarget]
Make a model prediction for inputs in input batch. Elements of input batch
are expected in the shape `(C, H, W)`.
Attributes
----------
metadata : ModelMetadata
A typed dictionary containing at least an 'id' field of type str
Examples
--------
We create a simple MAITE-compliant object detection model. Note that it is a dummy model.
>>> from dataclasses import dataclass
>>> from typing import Sequence
>>> import numpy as np
>>> import maite.protocols.object_detection as od
>>> from maite.protocols import ModelMetadata
We define an object detection target dataclass to serve as the output of the object detection model.
>>> @dataclass
... class MyObjectDetectionTarget:
...     boxes: np.ndarray
...     labels: np.ndarray
...     scores: np.ndarray
...
Specify parameters that will be used to create a dummy dataset.
>>> N_DATAPOINTS = 2 # datapoints in dataset
>>> N_CLASSES = 5 # possible classes that can be detected
>>> C = 3 # number of color channels
>>> H = 32 # img height
>>> W = 32 # img width
Now create a batch of data to form the inputs of the MAITE object detection model.
>>> simple_batch: list[np.ndarray] = [
...     np.random.rand(C, H, W) for _ in range(N_DATAPOINTS)
... ]
We define a dummy object detection model; rather than running an actual detector,
its __call__ method simply returns pre-defined MyObjectDetectionTarget instances.
>>> class ObjectDetectionDummyModel:
...     metadata: ModelMetadata = {"id": "ObjectDetectionDummyModel"}
...
...     def __call__(
...         self, batch: od.InputBatchType
...     ) -> Sequence[MyObjectDetectionTarget]:
...         # For simplicity, we don't run an actual detection model here;
...         # we just return a fixed set of detections per image
...         DETECTIONS_PER_IMG = (
...             2  # number of bounding box detections per image/datapoint
...         )
...         all_boxes = np.array(
...             [[1, 3, 5, 9], [2, 5, 8, 12], [4, 10, 8, 20], [3, 5, 6, 15]]
...         )  # all detection boxes for N_DATAPOINTS
...         all_predictions = list()
...         for datum_idx in range(N_DATAPOINTS):
...             boxes = all_boxes[datum_idx : datum_idx + DETECTIONS_PER_IMG]
...             labels = np.random.randint(N_CLASSES, size=DETECTIONS_PER_IMG)
...             scores = np.random.rand(DETECTIONS_PER_IMG)
...             predictions = MyObjectDetectionTarget(boxes, labels, scores)
...             all_predictions.append(predictions)
...         return all_predictions
...
We can instantiate this class and typehint it as a maite object detection model.
By using typehinting, we permit a static typechecker to verify protocol compliance.
>>> od_dummy_model: od.Model = ObjectDetectionDummyModel()
>>> od_dummy_model.metadata
{'id': 'ObjectDetectionDummyModel'}
>>> predictions = od_dummy_model(simple_batch)
>>> predictions # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
[MyObjectDetectionTarget(boxes=array([[ 1, 3, 5, 9], [ 2, 5, 8, 12]]), labels=array([..., ...]), scores=array([..., ...])),
MyObjectDetectionTarget(boxes=array([[ 2, 5, 8, 12], [ 4, 10, 8, 20]]), labels=array([..., ...]), scores=array([..., ...]))]
4 Datasets and DataLoaders
--------------------------
``Dataset``\ s provide access to single data items and ``DataLoader``\ s
provide access to batches of data with the input, target, and metadata
types corresponding to the given machine learning task.
.. code:: ipython3

    print(ic.Dataset.__doc__)
.. parsed-literal::
A dataset protocol for image classification ML subproblem providing datum-level
data access.
Implementers must provide index lookup (via `__getitem__(ind: int)` method) and
support `len` (via `__len__()` method). Data elements looked up this way correspond
to individual examples (as opposed to batches).
Indexing into or iterating over an image_classification dataset returns a
`tuple` of types `ArrayLike`, `ArrayLike`, and `DatumMetadataType`.
These correspond to the model input type, model target type, and datum-level
metadata, respectively.
Methods
-------
__getitem__(ind: int) -> tuple[ArrayLike, ArrayLike, DatumMetadataType]
Provide map-style access to dataset elements. Returned tuple elements
correspond to model input type, model target type, and datum-specific metadata type,
respectively.
__len__() -> int
Return the number of data elements in the dataset.
Attributes
----------
metadata : DatasetMetadata
A typed dictionary containing at least an 'id' field of type str
Examples
--------
We create a dummy set of data and use it to create a class that implements
this lightweight dataset protocol:
>>> import numpy as np
>>> from typing import Any, TypedDict
>>> from maite.protocols import ArrayLike, DatasetMetadata, DatumMetadata
Assume we have 5 classes, 10 datapoints, and 10 target labels, and that we want
to simply have an integer 'id' field in each datapoint's metadata:
>>> N_CLASSES: int = 5
>>> N_DATUM: int = 10
>>> images: list[np.ndarray] = [np.random.rand(3, 32, 16) for _ in range(N_DATUM)]
>>> targets: np.ndarray = np.eye(N_CLASSES)[np.random.choice(N_CLASSES, N_DATUM)]
We can type our datum metadata as a maite.protocols DatumMetadata, or define our
own TypedDict with additional fields:
>>> class MyDatumMetadata(DatumMetadata):
...     hour_of_day: float
...
>>> datum_metadata = [ MyDatumMetadata(id = i, hour_of_day=np.random.rand()*24) for i in range(N_DATUM) ]
Constructing a compliant dataset just involves a simple wrapper that fetches
individual datapoints, where a datapoint is a single image, target, metadata 3-tuple.
>>> class ImageDataset:
...     def __init__(self,
...                  dataset_name: str,
...                  index2label: dict[int, str],
...                  images: list[np.ndarray],
...                  targets: np.ndarray,
...                  datum_metadata: list[MyDatumMetadata]):
...         self.images = images
...         self.targets = targets
...         self.metadata = DatasetMetadata({'id': dataset_name, 'index2label': index2label})
...         self._datum_metadata = datum_metadata
...     def __len__(self) -> int:
...         return len(self.images)
...     def __getitem__(self, ind: int) -> tuple[np.ndarray, np.ndarray, MyDatumMetadata]:
...         return self.images[ind], self.targets[ind], self._datum_metadata[ind]
We can instantiate this class and typehint it as an image_classification.Dataset.
By using typehinting, we permit a static typechecker to verify protocol compliance.
>>> from maite.protocols import image_classification as ic
>>> dataset: ic.Dataset = ImageDataset('a_dataset',
...                                    {i: f"class_name_{i}" for i in range(N_CLASSES)},
...                                    images, targets, datum_metadata)
Note that when writing a Dataset implementer, return types may be narrower than the
return types promised by the protocol (np.ndarray is a subtype of ArrayLike), but
the argument types must be at least as general as the argument types promised by the
protocol.
.. code:: ipython3

    print(ic.DataLoader.__doc__)
.. parsed-literal::
A dataloader protocol for the image classification ML subproblem providing
batch-level data access.
Implementers must provide an iterable object (returning an iterator via the
`__iter__` method) that yields tuples containing batches of data. These tuples
contain types `Sequence[ArrayLike]` (elements of shape `(C, H, W)`),
`Sequence[ArrayLike]` (elements shape `(Cl, )`), and `Sequence[DatumMetadataType]`,
which correspond to a model input batch, model target batch, and a datum metadata batch.
Note: Unlike Dataset, this protocol does not require indexing support, only iterating.
Methods
-------
__iter__ -> Iterator[tuple[Sequence[ArrayLike], Sequence[ArrayLike], Sequence[DatumMetadataType]]]
Return an iterator over batches of data, where each batch contains a tuple of
model input batch (as `Sequence[ArrayLike]`), model target batch (as
`Sequence[ArrayLike]`), and batched datum-level metadata
(as `Sequence[DatumMetadataType]`), respectively.
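The ``DataLoader`` docstrings don't include an example, so here is a minimal
sketch of a compliant implementation: a hypothetical wrapper (not part of MAITE)
that batches any image classification ``Dataset``:

.. code:: python

    from collections.abc import Iterator

    import numpy as np
    import maite.protocols.image_classification as ic

    class SimpleDataLoader:
        # hypothetical helper: accumulates datum-level tuples from an ic.Dataset
        # and yields them as fixed-size (input, target, metadata) batch tuples
        def __init__(self, dataset: ic.Dataset, batch_size: int = 4):
            self.dataset = dataset
            self.batch_size = batch_size

        def __iter__(self) -> Iterator[tuple[list[np.ndarray], list[np.ndarray], list]]:
            inputs, targets, mds = [], [], []
            for i in range(len(self.dataset)):
                x, y, md = self.dataset[i]
                inputs.append(np.asarray(x))
                targets.append(np.asarray(y))
                mds.append(md)
                if len(inputs) == self.batch_size:
                    yield inputs, targets, mds
                    inputs, targets, mds = [], [], []
            if inputs:  # yield final partial batch
                yield inputs, targets, mds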
.. code:: ipython3

    print(od.DataLoader.__doc__)
.. parsed-literal::
A dataloader protocol for the object detection ML subproblem providing
batch-level data access.
Implementers must provide an iterable object (returning an iterator via the
`__iter__` method) that yields tuples containing batches of data. These tuples
contain types `Sequence[ArrayLike]` (elements of shape `(C, H, W)`),
`Sequence[ObjectDetectionTarget]`, and `Sequence[DatumMetadataType]`,
which correspond to model input batch, model target batch, and a datum metadata batch.
Note: Unlike Dataset, this protocol does not require indexing support, only iterating.
Methods
-------
__iter__ -> Iterator[tuple[Sequence[ArrayLike], Sequence[ObjectDetectionTarget], Sequence[DatumMetadataType]]]
Return an iterator over batches of data, where each batch contains a tuple of
model input batch (as `Sequence[ArrayLike]`), model target batch (as
`Sequence[ObjectDetectionTarget]`), and batched datum-level metadata
(as `Sequence[DatumMetadataType]`), respectively.
5 Augmentations
---------------
``Augmentation``\ s take in and return a batch of data with the
``InputBatchType``, ``TargetBatchType``, and ``DatumMetadataBatchType``
types corresponding to the given machine learning task.
Augmentations can access the datum-level metadata associated with each
data item to potentially tailor the augmentation to individual items.
Augmentations can also associate new datum-level metadata with each data
item, e.g., documenting aspects of the actual change that was applied
(e.g., the actual rotation angle sampled from a range of possible
angles).
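As a hedged illustration of that last point, a hypothetical augmentation (not
part of MAITE) might rotate each image by a random multiple of 90 degrees and
record the angle actually applied in that datum's metadata:

.. code:: python

    import numpy as np

    class RandomRightAngleRotation:
        # hypothetical augmentation: rotates each shape-(C, H, W) image by a
        # random multiple of 90 degrees and documents the applied change
        def __init__(self):
            self.metadata = {"id": "RandomRightAngleRotation"}
            self._rng = np.random.default_rng()

        def __call__(self, batch):
            xb, yb, mdb = batch
            xb_aug, mdb_aug = [], []
            for x, md in zip(xb, mdb):
                k = int(self._rng.integers(0, 4))  # sample rotation in {0, 90, 180, 270}
                xb_aug.append(np.rot90(np.asarray(x), k=k, axes=(1, 2)))
                mdb_aug.append({**md, "angle": 90 * k})  # record the sampled angle
            return xb_aug, yb, mdb_aug  # targets unchanged by this augmentation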
.. code:: ipython3

    print(ic.Augmentation.__doc__)
.. parsed-literal::
An augmentation protocol for the image classification subproblem.
An augmentation is expected to take a batch of data and return a modified version of
that batch. Implementers must provide a single method that takes and returns a
labeled data batch, where a labeled data batch is represented by a tuple of types
`Sequence[ArrayLike]` (with elements of shape `(C, H, W)`), `Sequence[ArrayLike]`
(with elements of shape `(Cl, )`), and `Sequence[DatumMetadataType]`. These correspond
to the model input batch type, model target batch type, and datum-level metadata
batch type, respectively.
Methods
-------
__call__(datum: tuple[Sequence[ArrayLike], Sequence[ArrayLike], Sequence[DatumMetadataType]]) -> tuple[Sequence[ArrayLike], Sequence[ArrayLike], Sequence[DatumMetadataType]]
Return a modified version of original data batch. A data batch is represented
by a tuple of model input batch (as `Sequence[ArrayLike]` with elements of shape
`(C, H, W)`), model target batch (as `Sequence[ArrayLike]` with elements of shape
`(Cl,)`), and batch metadata (as `Sequence[DatumMetadataType]`), respectively.
Attributes
----------
metadata : AugmentationMetadata
A typed dictionary containing at least an 'id' field of type str
Examples
--------
We can write an implementer of the augmentation class as either a function or a class.
The only requirement is that the object provide a __call__ method that takes objects
at least as general as the types promised in the protocol signature and return types
at least as specific.
>>> import copy
>>> import numpy as np
>>> from typing import Any
>>> from collections.abc import Sequence
>>> from maite.protocols import ArrayLike, DatumMetadata, AugmentationMetadata
>>>
>>> class EnrichedDatumMetadata(DatumMetadata):
...     new_key: int  # add a field to those already in DatumMetadata
...
>>> class ImageAugmentation:
...     def __init__(self, aug_name: str):
...         self.metadata: AugmentationMetadata = {'id': aug_name}
...     def __call__(
...         self,
...         data_batch: tuple[Sequence[ArrayLike], Sequence[ArrayLike], Sequence[DatumMetadata]]
...     ) -> tuple[Sequence[np.ndarray], Sequence[np.ndarray], Sequence[EnrichedDatumMetadata]]:
...         inputs, targets, mds = data_batch
...         # We copy data passed into `__call__` to avoid mutating the original inputs
...         # By using the np.ndarray constructor, the static type-checker will let us treat
...         # generic ArrayLike as a more narrow return type
...         inputs_aug = [copy.copy(np.array(input)) for input in inputs]
...         targets_aug = [copy.copy(np.array(target)) for target in targets]
...         # Modify inputs_aug, targets_aug, or mds_aug as needed
...         # In this example, we just add a new metadata field
...         mds_aug = []
...         for i, md in enumerate(mds):
...             mds_aug.append(EnrichedDatumMetadata(**md, new_key=i))
...         return inputs_aug, targets_aug, mds_aug
We can typehint an instance of the above class as an Augmentation in the
image_classification domain:
>>> from maite.protocols import image_classification as ic
>>> im_aug: ic.Augmentation = ImageAugmentation(aug_name = 'an_augmentation')
.. code:: ipython3

    print(od.Augmentation.__doc__)
.. parsed-literal::
An augmentation protocol for the object detection subproblem.
An augmentation is expected to take a batch of data and return a modified version of
that batch. Implementers must provide a single method that takes and returns a
labeled data batch, where a labeled data batch is represented by a tuple of types
`Sequence[ArrayLike]`, `Sequence[ObjectDetectionTarget]`, and `Sequence[DatumMetadataType]`.
These correspond to the model input batch type, model target batch type, and datum-level
metadata batch type, respectively.
Methods
-------
__call__(datum: Tuple[Sequence[ArrayLike], Sequence[ObjectDetectionTarget], Sequence[DatumMetadataType]]) -> Tuple[Sequence[ArrayLike], Sequence[ObjectDetectionTarget], Sequence[DatumMetadataType]]
Return a modified version of original data batch. A data batch is represented
by a tuple of model input batch (as `Sequence[ArrayLike]` with elements of shape
`(C, H, W)`), model target batch (as `Sequence[ObjectDetectionTarget]`), and
batch metadata (as `Sequence[DatumMetadataType]`), respectively.
Attributes
----------
metadata : AugmentationMetadata
A typed dictionary containing at least an 'id' field of type str
Examples
--------
We can write an implementer of the augmentation class as either a function or a class.
The only requirement is that the object provide a `__call__` method that takes objects
at least as general as the types promised in the protocol signature and return types
at least as specific.
We create a dummy set of data and use it to create a class that implements this
lightweight protocol and augments the data:
>>> import numpy as np
>>> np.random.seed(1)
>>> import copy
>>> from dataclasses import dataclass
>>> from typing import Any
>>> from maite.protocols import AugmentationMetadata, object_detection as od
First, we specify parameters that will be used to create the dummy dataset.
>>> N_DATAPOINTS = 3 # datapoints in dataset
>>> N_CLASSES = 2 # possible classes that can be detected
>>> C = 3 # number of color channels
>>> H = 10 # img height
>>> W = 10 # img width
Next, we create the input data to be used by the Augmentation. In this example, we
create the following batch of object detection data:
• `xb` is the input batch data. Our batch will include `N_DATAPOINTS` samples. Note
that we initialize all of the data to zeros in this example to better demonstrate
the effect of the augmentation.
• `yb` is the object detection target data, which in this example represents zero object
detections for each input datum (by having empty bounding boxes and class labels and scores).
• `mdb` is the associated metadata for each input datum.
>>> @dataclass
... class MyObjectDetectionTarget:
...     boxes: np.ndarray
...     labels: np.ndarray
...     scores: np.ndarray
>>> xb: od.InputBatchType = list(np.zeros((N_DATAPOINTS, C, H, W)))
>>> yb: od.TargetBatchType = list(
...     MyObjectDetectionTarget(boxes=np.empty(0), labels=np.empty(0), scores=np.empty(0))
...     for _ in range(N_DATAPOINTS)
... )
>>> mdb: od.DatumMetadataBatchType = list({"id": i} for i in range(N_DATAPOINTS))
>>> # Display the first datum in batch, first color channel, and only first 5 rows and cols
>>> np.array(xb[0])[0][:5, :5]
array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]])
Now we create the Augmentation, which will apply random noise (rounded to 3 decimal
places) to the input data using the `numpy.random.random` and `np.round` functions.
>>> np_noise = lambda shape: np.round(np.random.random(shape), 3)
>>> class ImageAugmentation:
...     def __init__(self, aug_func: Any, metadata: AugmentationMetadata):
...         self.aug_func = aug_func
...         self.metadata = metadata
...     def __call__(
...         self,
...         batch: tuple[od.InputBatchType, od.TargetBatchType, od.DatumMetadataBatchType],
...     ) -> tuple[od.InputBatchType, od.TargetBatchType, od.DatumMetadataBatchType]:
...         xb, yb, mdb = batch
...         # Copy data passed into `__call__` to avoid mutating the original inputs
...         xb_aug = [copy.copy(x) for x in xb]
...         # Add random noise to the input batch data, xb
...         # (Note that all batch data dimensions (shapes) are the same in this example)
...         shape = np.array(xb[0]).shape
...         xb_aug = [x + self.aug_func(shape) for x in xb_aug]
...         # Note that this example augmentation only affects inputs--not targets
...         return xb_aug, yb, mdb
We can typehint an instance of the above class as an Augmentation in the object
detection domain:
>>> noise: od.Augmentation = ImageAugmentation(np_noise, metadata={"id": "np_rand_noise"})
Now we can apply the `noise` augmentation to our 3-tuple batch of data. Recall that
our data was initialized to all zeros, so any non-zero values in the augmented data
are a result of the augmentation.
>>> xb_aug, yb_aug, mdb_aug = noise((xb, yb, mdb))
>>> # Display the first datum in batch, first color channel, and only first 5 rows and cols
>>> np.array(xb_aug[0])[0][:5, :5]
array([[0.417, 0.72 , 0.   , 0.302, 0.147],
       [0.419, 0.685, 0.204, 0.878, 0.027],
       [0.801, 0.968, 0.313, 0.692, 0.876],
       [0.098, 0.421, 0.958, 0.533, 0.692],
       [0.989, 0.748, 0.28 , 0.789, 0.103]])
6 Metrics
---------
The ``Metric`` protocol is inspired by the design of existing libraries
like Torchmetrics and Torcheval. The ``update`` method operates on
batches of predictions and truth labels by either caching them for later
computation of the metric (via ``compute``) or updating sufficient
statistics in an online fashion.
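The typical call pattern is sketched below as a minimal helper, assuming
compliant ``model``, ``dataloader``, and ``metric`` objects:

.. code:: python

    from typing import Any

    import maite.protocols.image_classification as ic

    def run_metric(model: ic.Model, dataloader: ic.DataLoader, metric: ic.Metric) -> dict[str, Any]:
        # a minimal sketch of the update/compute cycle described above
        metric.reset()
        for input_batch, target_batch, metadata_batch in dataloader:
            preds = model(input_batch)          # batch of model predictions
            metric.update(preds, target_batch)  # cache or update sufficient statistics
        return metric.compute()                 # final metric value(s) as a dict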
.. code:: ipython3

    print(ic.Metric.__doc__)
.. parsed-literal::
A metric protocol for the image classification ML subproblem.
A metric in this sense is expected to measure the level of agreement between model
predictions and ground-truth labels.
Methods
-------
update(preds: Sequence[ArrayLike], targets: Sequence[ArrayLike]) -> None
Add predictions and targets to metric's cache for later calculation. Both
preds and targets are expected to be sequences with elements of shape `(Cl,)`.
compute() -> dict[str, Any]
Compute metric value(s) for currently cached predictions and targets, returned as
a dictionary.
reset() -> None
Clear contents of current metric's cache of predictions and targets.
Attributes
----------
metadata : MetricMetadata
A typed dictionary containing at least an 'id' field of type str
Examples
--------
Create a basic accuracy metric and test it on a small example dataset:
>>> from typing import Any, Sequence
>>> import numpy as np
>>> from maite.protocols import ArrayLike, MetricMetadata
>>> from maite.protocols import image_classification as ic
>>> class MyAccuracy:
...     metadata: MetricMetadata = {'id': 'Example Multiclass Accuracy'}
...
...     def __init__(self):
...         self._total = 0
...         self._correct = 0
...
...     def reset(self) -> None:
...         self._total = 0
...         self._correct = 0
...
...     def update(self, preds: Sequence[ArrayLike], targets: Sequence[ArrayLike]) -> None:
...         model_probs = [np.array(r) for r in preds]
...         true_onehot = [np.array(r) for r in targets]
...
...         # Stack into single array, convert to class indices
...         model_classes = np.vstack(model_probs).argmax(axis=1)
...         truth_classes = np.vstack(true_onehot).argmax(axis=1)
...
...         # Compare classes and update running counts
...         same = (model_classes == truth_classes)
...         self._total += len(same)
...         self._correct += same.sum()
...
...     def compute(self) -> dict[str, Any]:
...         if self._total > 0:
...             return {"accuracy": self._correct / self._total}
...         else:
...             raise Exception("No batches processed yet.")
Instantiate this class and typehint it as an image_classification.Metric.
By using typehinting, we permit a static typechecker to verify protocol compliance.
>>> accuracy: ic.Metric = MyAccuracy()
To use the metric, call update() for each batch of predictions and truth values, then call compute() to calculate the final metric values.
>>> # batch 1
>>> model_probs = [np.array([0.8, 0.1, 0.0, 0.1]), np.array([0.1, 0.2, 0.6, 0.1])] # predicted classes: 0, 2
>>> true_onehot = [np.array([1.0, 0.0, 0.0, 0.0]), np.array([0.0, 1.0, 0.0, 0.0])] # true classes: 0, 1
>>> accuracy.update(model_probs, true_onehot)
>>> print(accuracy.compute())
{'accuracy': 0.5}
>>>
>>> # batch 2
>>> model_probs = [np.array([0.1, 0.1, 0.7, 0.1]), np.array([0.0, 0.1, 0.0, 0.9])] # predicted classes: 2, 3
>>> true_onehot = [np.array([0.0, 0.0, 1.0, 0.0]), np.array([0.0, 0.0, 0.0, 1.0])] # true classes: 2, 3
>>> accuracy.update(model_probs, true_onehot)
>>>
>>> print(accuracy.compute())
{'accuracy': 0.75}
.. code:: ipython3

    print(od.Metric.__doc__)
.. parsed-literal::
A metric protocol for the object detection ML subproblem.
A metric in this sense is expected to measure the level of agreement between model
predictions and ground-truth labels.
Methods
-------
update(preds: Sequence[ObjectDetectionTarget], targets: Sequence[ObjectDetectionTarget]) -> None
Add predictions and targets to metric's cache for later calculation.
compute() -> dict[str, Any]
Compute metric value(s) for currently cached predictions and targets, returned as
a dictionary.
reset() -> None
Clear contents of current metric's cache of predictions and targets.
Attributes
----------
metadata : MetricMetadata
A typed dictionary containing at least an 'id' field of type str
Examples
--------
Below, we write and test a class that implements the Metric protocol for object detection.
For simplicity, the metric will compute the intersection over union (IoU) averaged over
all predicted and associated ground truth boxes for a single image, and then take the mean
over the per-image means.
Note that when writing a `Metric` implementer, return types may be narrower than the
return types promised by the protocol, but the argument types must be at least as
general as the argument types promised by the protocol.
>>> from dataclasses import dataclass
>>> from maite.protocols import ArrayLike, MetricMetadata
>>> from typing import Any, Sequence
>>> import maite.protocols.object_detection as od
>>> import numpy as np
>>> class MyIoUMetric:
...
...     def __init__(self, id: str):
...         self.pred_boxes = []  # elements correspond to predicted boxes in single image
...         self.target_boxes = []  # elements correspond to ground truth boxes in single image
...         # Store provided id for this metric instance
...         self.metadata = MetricMetadata(
...             id=id
...         )
...
...     def reset(self) -> None:
...         self.pred_boxes = []
...         self.target_boxes = []
...
...     def update(
...         self,
...         pred_batch: Sequence[od.ObjectDetectionTarget],
...         target_batch: Sequence[od.ObjectDetectionTarget],
...     ) -> None:
...         self.pred_boxes.extend(pred_batch)
...         self.target_boxes.extend(target_batch)
...
...     @staticmethod
...     def iou_vec(boxes_a: ArrayLike, boxes_b: ArrayLike) -> np.ndarray:
...         # Break up points into separate columns
...         x0a, y0a, x1a, y1a = np.split(boxes_a, 4, axis=1)
...         x0b, y0b, x1b, y1b = np.split(boxes_b, 4, axis=1)
...         # Calculate intersections
...         xi_0, yi_0 = np.split(
...             np.maximum(np.append(x0a, y0a, axis=1), np.append(x0b, y0b, axis=1)),
...             2,
...             axis=1,
...         )
...         xi_1, yi_1 = np.split(
...             np.minimum(np.append(x1a, y1a, axis=1), np.append(x1b, y1b, axis=1)),
...             2,
...             axis=1,
...         )
...         ints: np.ndarray = np.maximum(0, xi_1 - xi_0) * np.maximum(0, yi_1 - yi_0)
...         # Calculate unions (as sum of areas minus their intersection)
...         unions: np.ndarray = (
...             (x1a - x0a) * (y1a - y0a)
...             + (x1b - x0b) * (y1b - y0b)
...             - (xi_1 - xi_0) * (yi_1 - yi_0)
...         )
...         return ints / unions
...
...     def compute(self) -> dict[str, Any]:
...         mean_iou_by_img: list[float] = []
...         for pred_box, tgt_box in zip(self.pred_boxes, self.target_boxes):
...             single_img_ious = self.iou_vec(pred_box.boxes, tgt_box.boxes)
...             mean_iou_by_img.append(float(np.mean(single_img_ious)))
...         return {"mean_iou": np.mean(np.array(mean_iou_by_img))}
...
Now we can instantiate our IoU Metric class:
>>> iou_metric: od.Metric = MyIoUMetric(id="IoUMetric")
To use the metric, we populate two lists that encode predicted object detections
and ground truth object detections for a single image. (Ordinarily, predictions
would be produced by a model.)
>>> prediction_boxes: list[tuple[int, int, int, int]] = [
...     (1, 1, 12, 12),
...     (100, 100, 120, 120),
...     (180, 180, 270, 270),
... ]
>>> target_boxes: list[tuple[int, int, int, int]] = [
...     (1, 1, 10, 10),
...     (100, 100, 120, 120),
...     (200, 200, 300, 300),
... ]
The MAITE Metric protocol requires the `pred_batch` and `target_batch` arguments to the
`update` method to be assignable to Sequence[ObjectDetectionTarget] (where ObjectDetectionTarget
encodes detections in a single image). We define an implementation of ObjectDetectionTarget and use it
to pass ground truth and predicted detections.
>>> @dataclass
... class ObjectDetectionTargetImpl:
...     boxes: np.ndarray
...     labels: np.ndarray
...     scores: np.ndarray
>>> num_boxes = len(target_boxes)
>>> fake_labels = np.random.randint(0, 9, num_boxes)
>>> fake_scores = np.zeros(num_boxes)
>>> pred_batch = [
...     ObjectDetectionTargetImpl(
...         boxes=np.array(prediction_boxes), labels=fake_labels, scores=fake_scores
...     )
... ]
>>> target_batch: Sequence[ObjectDetectionTargetImpl] = [
...     ObjectDetectionTargetImpl(
...         boxes=np.array(target_boxes), labels=fake_labels, scores=fake_scores
...     )
... ]
Finally, we call `update` using this one-element batch, compute the metric value, and print it:
>>> iou_metric.update(pred_batch, target_batch)
>>> print(iou_metric.compute())
{'mean_iou': 0.6802112029384757}
7 Workflows
-----------
MAITE provides high-level utilities for common workflows such as
``evaluate`` and ``predict``. They can be called with either
``Dataset``\ s or ``DataLoader``\ s, and with an optional ``Augmentation``.
The ``evaluate`` function can optionally return the model predictions
and (potentially-augmented) data batches used during inference.
The ``predict`` function returns the model predictions and
(potentially-augmented) data batches used during inference, essentially
calling ``evaluate`` with a dummy metric.
.. code:: ipython3

    from maite.workflows import evaluate, predict

    # we can also import from the object_detection module,
    # where the function call signature is the same
.. code:: ipython3

    print(evaluate.__doc__)
.. parsed-literal::
Evaluate a model's performance on data according to some metric with optional augmentation.
Some data source (either a dataloader or a dataset) must be provided
or an InvalidArgument exception is raised.
Parameters
----------
model : SomeModel
Maite Model object.
metric : SomeMetric | None, (default=None)
Compatible maite Metric.
dataloader : SomeDataloader | None, (default=None)
Compatible maite dataloader.
dataset : SomeDataset | None, (default=None)
Compatible maite dataset.
batch_size : int, (default=1)
Batch size for use with dataset (ignored if dataset=None).
augmentation : SomeAugmentation | None, (default=None)
Compatible maite augmentation.
return_augmented_data : bool, (default=False)
Set to True to return post-augmentation data as a function output.
return_preds : bool, (default=False)
Set to True to return raw predictions as a function output.
Returns
-------
tuple[dict[str, Any], Sequence[TargetType], Sequence[tuple[InputBatchType, TargetBatchType, DatumMetadataBatchType]]]
Tuple of returned metric value, sequence of model predictions, and
sequence of data batch tuples fed to the model during inference. The actual
types represented by InputBatchType, TargetBatchType, and DatumMetadataBatchType will vary
by the domain of the components provided as input arguments (e.g. image
classification or object detection.)
Note that the second and third return arguments will be empty if
return_preds is False or return_augmented_data is False, respectively.
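For example, a minimal sketch of a call (here ``model``, ``dataset``, ``metric``,
and ``augmentation`` stand for MAITE-compliant components defined elsewhere):

.. code:: python

    from maite.workflows import evaluate

    # model, dataset, metric, and augmentation are assumed to be
    # MAITE-compliant components defined elsewhere
    metric_result, preds, aug_batches = evaluate(
        model=model,
        dataset=dataset,
        metric=metric,
        batch_size=4,
        augmentation=augmentation,
        return_preds=True,
        return_augmented_data=True,
    )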
.. code:: ipython3

    print(predict.__doc__)
.. parsed-literal::
Make predictions for a given model & data source with optional augmentation.
Some data source (either a dataloader or a dataset) must be provided
or an InvalidArgument exception is raised.
Parameters
----------
model : SomeModel
Maite Model object.
dataloader : SomeDataloader | None, (default=None)
Compatible maite dataloader.
dataset : SomeDataset | None, (default=None)
Compatible maite dataset.
batch_size : int, (default=1)
Batch size for use with dataset (ignored if dataset=None).
augmentation : SomeAugmentation | None, (default=None)
Compatible maite augmentation.
Returns
-------
tuple[Sequence[SomeTargetBatchType], Sequence[tuple[SomeInputBatchType, SomeTargetBatchType, SomeMetadataBatchType]]]
A tuple of the predictions (as a sequence of batches) and a sequence
of tuples containing the information associated with each batch.
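Analogously, a minimal sketch of a ``predict`` call (again with ``model`` and
``dataset`` standing for compliant components defined elsewhere):

.. code:: python

    from maite.workflows import predict

    # model and dataset are assumed to be MAITE-compliant components
    preds, batches = predict(model=model, dataset=dataset, batch_size=4)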