Copyright 2024, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
SPDX-License-Identifier: MIT
Overview of MAITE Protocols

MAITE provides protocols for the following AI components:

- models
- datasets
- dataloaders
- augmentations
- metrics

MAITE protocols specify the expected interfaces of these components (i.e., a
minimal set of required attributes, methods, and method type signatures)
to promote interoperability in test and evaluation (T&E). This enables
the creation of higher-level workflows (e.g., an evaluate utility)
that can interact with any components that conform to the protocols.
1 Concept: Bridging ArrayLikes
MAITE uses a type called ArrayLike (following NumPy's interoperability
approach) that helps components that natively use different flavors of
tensors (e.g., NumPy ndarray, PyTorch Tensor, JAX ndarray) work together.

In the example below, the functions "type narrow" from ArrayLike to the
type they want to work with internally. Note that this doesn't
necessarily require a conversion, depending on the actual input type.
import numpy as np
import torch
from maite.protocols import ArrayLike
def my_numpy_fn(x: ArrayLike) -> np.ndarray:
arr = np.asarray(x)
# ...
return arr
def my_torch_fn(x: ArrayLike) -> torch.Tensor:
tensor = torch.as_tensor(x)
# ...
return tensor
# can apply NumPy function to PyTorch Tensor
np_out = my_numpy_fn(torch.rand(2, 3))
# can apply PyTorch function to NumPy array
torch_out = my_torch_fn(np.random.rand(2, 3))
# note: no performance hit from conversion when all `ArrayLike`s are from the same library
# or when they can share the same underlying memory
torch_out = my_torch_fn(torch.rand(2, 3))
By using bridging, MAITE permits implementers of the protocols to interact internally with their own types while exposing a more open interface to other MAITE-compliant components.
2 Data Types
MAITE represents an individual data item as a tuple of:

- input (i.e., image),
- target (i.e., label), and
- metadata (at the datum level)

and a batch of data items as a tuple of:

- input batches,
- target batches, and
- metadata batches.
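To make this concrete, here is a minimal sketch (with made-up values) of a single image datum and a batch of two:

import numpy as np

# single datum: (input, target, metadata)
datum = (
    np.random.rand(3, 32, 32),  # input: shape-(C, H, W) image
    np.array([0.0, 1.0, 0.0]),  # target: e.g., one-hot label over 3 classes
    {"id": 0},                  # datum-level metadata with an 'id' field
)

# batch of two data items: (input batch, target batch, metadata batch)
batch = (
    [np.random.rand(3, 32, 32) for _ in range(2)],
    [np.array([1.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0])],
    [{"id": 0}, {"id": 1}],
)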
MAITE provides versions of Model, Dataset, DataLoader, Augmentation, and
Metric protocols that correspond to different machine learning tasks
(e.g., image classification, object detection) by parameterizing protocol
interfaces on the particular input, target, and metadata types associated
with that task.
2.1 Image Classification

For image classification with Cl image classes, we have:
# define type to store an id of each datum (additional fields can be added by defining a structurally-assignable TypedDict)
class DatumMetadataType(TypedDict):
    id: str | int

InputType: TypeAlias = ArrayLike  # shape-(C, H, W) tensor with single image
TargetType: TypeAlias = ArrayLike  # shape-(Cl,) tensor of one-hot encoded true class or predicted probabilities
InputBatchType: TypeAlias = Sequence[ArrayLike]  # sequence of N shape-(C, H, W) tensors
TargetBatchType: TypeAlias = Sequence[ArrayLike]  # sequence of N shape-(Cl,) tensors
DatumMetadataBatchType: TypeAlias = Sequence[DatumMetadataType]
Notes:

- TargetType is used for both ground truth (coming from a dataset) and predictions (output from a model). So for a problem with 4 classes, a true label of class 2 would be one-hot encoded as [0, 0, 1, 0], while a prediction from a model would be a vector of pseudo-probabilities, e.g., [0.1, 0.0, 0.7, 0.2].
- InputType and InputBatchType are shown with shapes following PyTorch's channels-first convention.
These type aliases, along with the versions of the various component
protocols that use these types, can be imported from
maite.protocols.image_classification (if needed):
# import protocol classes
from maite.protocols.image_classification import (
Dataset,
DataLoader,
Model,
Augmentation,
Metric
)
# import type aliases
from maite.protocols.image_classification import (
InputType,
TargetType,
DatumMetadataType,
InputBatchType,
TargetBatchType,
DatumMetadataBatchType
)
Alternatively, image classification components and types can be accessed via the module directly:
import maite.protocols.image_classification as ic
# model: ic.Model = load_model(...)
2.2 Object Detection

For object detection with D_i detections in an image i, we have:
# define type to store an id of each datum (additional fields can be added by defining a structurally-assignable TypedDict)
class DatumMetadataType(TypedDict):
    id: str | int

class ObjectDetectionTarget(Protocol):
    @property
    def boxes(self) -> ArrayLike: ...  # shape-(D_i, 4) tensor of bounding boxes with format X0, Y0, X1, Y1

    @property
    def labels(self) -> ArrayLike: ...  # shape-(D_i,) tensor of labels for each box

    @property
    def scores(self) -> ArrayLike: ...  # shape-(D_i,) tensor of scores for each box (e.g., probabilities)

InputType: TypeAlias = ArrayLike  # shape-(C, H, W) tensor with single image
TargetType: TypeAlias = ObjectDetectionTarget
InputBatchType: TypeAlias = Sequence[ArrayLike]  # sequence of N ArrayLikes, each of shape (C, H, W)
TargetBatchType: TypeAlias = Sequence[TargetType]  # sequence of object detection "target" objects
DatumMetadataBatchType: TypeAlias = Sequence[DatumMetadataType]
Notes:

- ObjectDetectionTarget contains a single label and score per box.
- InputType and InputBatchType are shown with shapes following PyTorch's channels-first convention.
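As a minimal sketch (our own illustrative code, mirroring the pattern used in the docstring examples below), a dataclass with boxes, labels, and scores array fields is structurally assignable to ObjectDetectionTarget:

from dataclasses import dataclass
import numpy as np
import maite.protocols.object_detection as od

@dataclass
class MyObjectDetectionTarget:
    boxes: np.ndarray   # shape-(D_i, 4), format X0, Y0, X1, Y1
    labels: np.ndarray  # shape-(D_i,)
    scores: np.ndarray  # shape-(D_i,)

# two detections in one image; the typehint lets a static checker verify compliance
target: od.TargetType = MyObjectDetectionTarget(
    boxes=np.array([[0, 0, 10, 10], [5, 5, 20, 20]]),
    labels=np.array([1, 3]),
    scores=np.array([0.9, 0.5]),
)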
3 Models

All models implement a __call__ method that takes the InputBatchType and
produces the TargetBatchType appropriate for the given machine learning
task.
import maite.protocols.image_classification as ic
print(ic.Model.__doc__)
A model protocol for the image classification ML subproblem.

Implementers must provide a __call__ method that operates on a batch of
model inputs (as Sequence[ArrayLike]) and returns a batch of model targets
(as Sequence[ArrayLike]).

Methods
-------
__call__(input_batch: Sequence[ArrayLike]) -> Sequence[ArrayLike]
    Make a model prediction for inputs in input batch. Input batch is
    expected to be Sequence[ArrayLike] with each element of shape (C, H, W).

Attributes
----------
metadata : ModelMetadata
    A typed dictionary containing at least an 'id' field of type str

Examples
--------
We create a multinomial logistic regression classifier for a CIFAR-10-like
dataset with 10 classes and shape-(3, 32, 32) images.

>>> import maite.protocols.image_classification as ic
>>> import numpy as np
>>> import numpy.typing as npt
>>> from maite.protocols import ArrayLike, ModelMetadata
>>> from typing import Sequence

Creating a MAITE-compliant model involves writing a __call__ method that
takes a batch of inputs and returns a batch of predictions (probabilities).

>>> class LinearClassifier:
...     def __init__(self) -> None:
...         # Set up required metadata attribute using the default ModelMetadata
...         # type, using class name for the ID
...         self.metadata: ModelMetadata = {"id": self.__class__.__name__}
...
...         # Initialize weights
...         rng = np.random.default_rng(12345678)
...         num_classes = 10
...         flattened_size = 3 * 32 * 32
...         self.weights = -0.2 + 0.4 * rng.random((flattened_size, num_classes))
...         self.bias = -0.2 + 0.4 * rng.random((1, num_classes))
...
...     def __call__(self, batch: Sequence[ArrayLike]) -> Sequence[npt.NDArray]:
...         # Convert each element in batch to ndarray, flatten,
...         # then combine into 4D array of shape-(N, C, H, W)
...         batch_np = np.vstack([np.asarray(x).flatten() for x in batch])
...
...         # Send input batch through model
...         out = batch_np @ self.weights + self.bias
...         out = np.exp(out) / np.sum(np.exp(out), axis=1, keepdims=True)  # softmax
...
...         # Restructure to sequence of shape-(10,) probabilities
...         return [row for row in out]

We set up a test batch, instantiate the model, and apply it to the batch.

>>> batch_size = 8
>>> rng = np.random.default_rng(12345678)
>>> batch: Sequence[ArrayLike] = [-0.2 + 0.4 * rng.random((3, 32, 32)) for _ in range(batch_size)]
>>>
>>> model: ic.Model = LinearClassifier()
>>> out = model(batch)

We can now show the class probabilities returned by the model for each image
in the batch.

>>> for probs in out:
...     print(np.round(probs, 2))
[0.16 0.1  0.16 0.14 0.04 0.02 0.06 0.04 0.17 0.1 ]
[0.21 0.16 0.04 0.07 0.08 0.05 0.09 0.03 0.18 0.09]
[0.15 0.11 0.13 0.11 0.09 0.09 0.07 0.04 0.19 0.02]
[0.04 0.08 0.14 0.07 0.12 0.2  0.11 0.06 0.14 0.04]
[0.03 0.08 0.06 0.05 0.17 0.18 0.09 0.03 0.12 0.19]
[0.09 0.04 0.1  0.03 0.32 0.05 0.07 0.04 0.15 0.09]
[0.15 0.05 0.1  0.05 0.11 0.14 0.04 0.08 0.08 0.2 ]
[0.11 0.11 0.08 0.11 0.08 0.05 0.24 0.03 0.08 0.12]

Note that when writing a Model implementer, return types may be narrower
than the return types promised by the protocol (npt.NDArray is a subtype of
ArrayLike), but the argument types must be at least as general as the
argument types promised by the protocol.
import maite.protocols.object_detection as od
print(od.Model.__doc__)
A model protocol for the object detection ML subproblem.

Implementers must provide a __call__ method that operates on a batch of
model inputs (as Sequence[ArrayLike]) and returns a batch of model targets
(as Sequence[ObjectDetectionTarget]).

Methods
-------
__call__(input_batch: Sequence[ArrayLike]) -> Sequence[ObjectDetectionTarget]
    Make a model prediction for inputs in input batch. Elements of input
    batch are expected in the shape (C, H, W).

Attributes
----------
metadata : ModelMetadata
    A typed dictionary containing at least an 'id' field of type str

Examples
--------
We create a simple MAITE-compliant object detection model; note that it is
a dummy model.

>>> from dataclasses import dataclass
>>> from typing import Sequence
>>> import numpy as np
>>> import maite.protocols.object_detection as od
>>> from maite.protocols import ModelMetadata

We define an object detection target dataclass as the output of the object
detection model.

>>> @dataclass
... class MyObjectDetectionTarget:
...     boxes: np.ndarray
...     labels: np.ndarray
...     scores: np.ndarray
...

Specify parameters that will be used to create a dummy dataset.

>>> N_DATAPOINTS = 2  # datapoints in dataset
>>> N_CLASSES = 5  # possible classes that can be detected
>>> C = 3  # number of color channels
>>> H = 32  # img height
>>> W = 32  # img width

Now create a batch of data to form the inputs of the MAITE object detection
model.

>>> simple_batch: list[np.ndarray] = [
...     np.random.rand(C, H, W) for _ in range(N_DATAPOINTS)
... ]

We define a simple object detection model; note that there is no actual
object detection model here. In the __call__ method, it just outputs
MyObjectDetectionTarget objects.

>>> class ObjectDetectionDummyModel:
...     metadata: ModelMetadata = {"id": "ObjectDetectionDummyModel"}
...
...     def __call__(
...         self, batch: od.InputBatchType
...     ) -> Sequence[MyObjectDetectionTarget]:
...         # For simplicity, we don't run an object detection model here;
...         # we just return output in the expected format
...         DETECTIONS_PER_IMG = (
...             2  # number of bounding box detections per image/datapoint
...         )
...         all_boxes = np.array(
...             [[1, 3, 5, 9], [2, 5, 8, 12], [4, 10, 8, 20], [3, 5, 6, 15]]
...         )  # all detection boxes for N_DATAPOINTS
...         all_predictions = list()
...         for datum_idx in range(N_DATAPOINTS):
...             boxes = all_boxes[datum_idx : datum_idx + DETECTIONS_PER_IMG]
...             labels = np.random.randint(N_CLASSES, size=DETECTIONS_PER_IMG)
...             scores = np.random.rand(DETECTIONS_PER_IMG)
...             predictions = MyObjectDetectionTarget(boxes, labels, scores)
...             all_predictions.append(predictions)
...         return all_predictions
...

We can instantiate this class and typehint it as a maite object detection
model. By using typehinting, we permit a static typechecker to verify
protocol compliance.

>>> od_dummy_model: od.Model = ObjectDetectionDummyModel()
>>> od_dummy_model.metadata
{'id': 'ObjectDetectionDummyModel'}
>>> predictions = od_dummy_model(simple_batch)
>>> predictions  # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
[MyObjectDetectionTarget(boxes=array([[ 1,  3,  5,  9], [ 2,  5,  8, 12]]),
labels=array([..., ...]), scores=array([..., ...])),
MyObjectDetectionTarget(boxes=array([[ 2,  5,  8, 12], [ 4, 10,  8, 20]]),
labels=array([..., ...]), scores=array([..., ...]))]
4 Datasets and DataLoaders

Datasets provide access to single data items and DataLoaders provide
access to batches of data with the input, target, and metadata types
corresponding to the given machine learning task.
print(ic.Dataset.__doc__)
A dataset protocol for image classification ML subproblem providing
datum-level data access.

Implementers must provide index lookup (via __getitem__(ind: int) method)
and support len (via __len__() method). Data elements looked up this way
correspond to individual examples (as opposed to batches).

Indexing into or iterating over an image_classification dataset returns a
tuple of types ArrayLike, ArrayLike, and DatumMetadataType. These correspond
to the model input type, model target type, and datum-level metadata,
respectively.

Methods
-------
__getitem__(ind: int) -> tuple[ArrayLike, ArrayLike, DatumMetadataType]
    Provide map-style access to dataset elements. Returned tuple elements
    correspond to model input type, model target type, and datum-specific
    metadata type, respectively.
__len__() -> int
    Return the number of data elements in the dataset.

Attributes
----------
metadata : DatasetMetadata
    A typed dictionary containing at least an 'id' field of type str

Examples
--------
We create a dummy set of data and use it to create a class that implements
this lightweight dataset protocol:

>>> import numpy as np
>>> from typing import Any, TypedDict
>>> from maite.protocols import ArrayLike, DatasetMetadata, DatumMetadata

Assume we have 5 classes, 10 datapoints, and 10 target labels, and that we
want to simply have an integer 'id' field in each datapoint's metadata:

>>> N_CLASSES: int = 5
>>> N_DATUM: int = 10
>>> images: list[np.ndarray] = [np.random.rand(3, 32, 16) for _ in range(N_DATUM)]
>>> targets: np.ndarray = np.eye(N_CLASSES)[np.random.choice(N_CLASSES, N_DATUM)]

We can type our datum metadata as a maite.protocols DatumMetadata, or define
our own TypedDict with additional fields:

>>> class MyDatumMetadata(DatumMetadata):
...     hour_of_day: float
...
>>> datum_metadata = [MyDatumMetadata(id=i, hour_of_day=np.random.rand() * 24) for i in range(N_DATUM)]

Constructing a compliant dataset just involves a simple wrapper that fetches
individual datapoints, where a datapoint is a single image, target, metadata
3-tuple.

>>> class ImageDataset:
...     def __init__(self,
...                  dataset_name: str,
...                  index2label: dict[int, str],
...                  images: list[np.ndarray],
...                  targets: np.ndarray,
...                  datum_metadata: list[MyDatumMetadata]):
...         self.images = images
...         self.targets = targets
...         self.metadata = DatasetMetadata({'id': dataset_name, 'index2label': index2label})
...         self._datum_metadata = datum_metadata
...     def __len__(self) -> int:
...         return len(self.images)
...     def __getitem__(self, ind: int) -> tuple[np.ndarray, np.ndarray, MyDatumMetadata]:
...         return self.images[ind], self.targets[ind], self._datum_metadata[ind]

We can instantiate this class and typehint it as an image_classification.Dataset.
By using typehinting, we permit a static typechecker to verify protocol
compliance.

>>> from maite.protocols import image_classification as ic
>>> dataset: ic.Dataset = ImageDataset('a_dataset',
...                                    {i: f"class_name_{i}" for i in range(N_CLASSES)},
...                                    images, targets, datum_metadata)

Note that when writing a Dataset implementer, return types may be narrower
than the return types promised by the protocol (np.ndarray is a subtype of
ArrayLike), but the argument types must be at least as general as the
argument types promised by the protocol.
print(ic.DataLoader.__doc__)
A dataloader protocol for the image classification ML subproblem providing
batch-level data access.

Implementers must provide an iterable object (returning an iterator via the
__iter__ method) that yields tuples containing batches of data. These tuples
contain types Sequence[ArrayLike] (elements of shape (C, H, W)),
Sequence[ArrayLike] (elements of shape (Cl,)), and Sequence[DatumMetadataType],
which correspond to model input batch, model target batch, and a datum
metadata batch.

Note: Unlike Dataset, this protocol does not require indexing support, only iterating.

Methods
-------
__iter__ -> Iterator[tuple[Sequence[ArrayLike], Sequence[ArrayLike], Sequence[DatumMetadataType]]]
    Return an iterator over batches of data, where each batch contains a
    tuple of model input batch (as Sequence[ArrayLike]), model target batch
    (as Sequence[ArrayLike]), and batched datum-level metadata (as
    Sequence[DatumMetadataType]), respectively.
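The DataLoader docstrings do not include examples, so here is a minimal sketch (our own code, with illustrative names) of a compliant dataloader that batches an underlying image classification Dataset:

import maite.protocols.image_classification as ic

class SimpleDataLoader:
    def __init__(self, dataset: ic.Dataset, batch_size: int = 4):
        self.dataset = dataset
        self.batch_size = batch_size
        self.metadata = {"id": "SimpleDataLoader"}  # 'id' metadata, mirroring the other components

    def __iter__(self):
        # yield (input batch, target batch, metadata batch) tuples
        n = len(self.dataset)
        for start in range(0, n, self.batch_size):
            items = [self.dataset[i] for i in range(start, min(start + self.batch_size, n))]
            xb = [x for x, _, _ in items]
            yb = [y for _, y, _ in items]
            mdb = [md for _, _, md in items]
            yield xb, yb, mdb

# dataloader: ic.DataLoader = SimpleDataLoader(dataset)  # typehint to check protocol compliance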
print(od.DataLoader.__doc__)
A dataloader protocol for the object detection ML subproblem providing
batch-level data access.

Implementers must provide an iterable object (returning an iterator via the
__iter__ method) that yields tuples containing batches of data. These tuples
contain types Sequence[ArrayLike] (elements of shape (C, H, W)),
Sequence[ObjectDetectionTarget], and Sequence[DatumMetadataType], which
correspond to model input batch, model target batch, and a datum metadata
batch.

Note: Unlike Dataset, this protocol does not require indexing support, only iterating.

Methods
-------
__iter__ -> Iterator[tuple[Sequence[ArrayLike], Sequence[ObjectDetectionTarget], Sequence[DatumMetadataType]]]
    Return an iterator over batches of data, where each batch contains a
    tuple of model input batch (as Sequence[ArrayLike]), model target batch
    (as Sequence[ObjectDetectionTarget]), and batched datum-level metadata
    (as Sequence[DatumMetadataType]), respectively.
5 Augmentations

Augmentations take in and return a batch of data with the InputBatchType,
TargetBatchType, and DatumMetadataBatchType types corresponding to the
given machine learning task.
Augmentations can access the datum-level metadata associated with each data item to potentially tailor the augmentation to individual items. Augmentations can also associate new datum-level metadata with each data item, e.g., documenting aspects of the change actually applied (such as the rotation angle sampled from a range of possible angles).
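For instance, here is a minimal sketch (illustrative code, not part of the library) of an augmentation that rotates each image by a random multiple of 90 degrees and records the sampled angle in that datum's metadata:

import numpy as np
import maite.protocols.image_classification as ic

class RandomRotate90:
    def __init__(self):
        self.metadata = {"id": "RandomRotate90"}
        self._rng = np.random.default_rng()

    def __call__(self, batch):
        xb, yb, mdb = batch
        xb_aug, mdb_aug = [], []
        for x, md in zip(xb, mdb):
            k = int(self._rng.integers(0, 4))  # number of 90-degree turns
            # rotate in the (H, W) plane of the shape-(C, H, W) image
            xb_aug.append(np.rot90(np.asarray(x), k=k, axes=(1, 2)))
            # document the change actually applied in new datum-level metadata
            mdb_aug.append({**md, "rotation_deg": 90 * k})
        # class-label targets are unaffected by rotation
        return xb_aug, yb, mdb_aug

# rot_aug: ic.Augmentation = RandomRotate90()  # typehint to check protocol compliance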
print(ic.Augmentation.__doc__)
An augmentation protocol for the image classification subproblem.

An augmentation is expected to take a batch of data and return a modified
version of that batch. Implementers must provide a single method that takes
and returns a labeled data batch, where a labeled data batch is represented
by a tuple of types Sequence[ArrayLike] (with elements of shape (C, H, W)),
Sequence[ArrayLike] (with elements of shape (Cl,)), and
Sequence[DatumMetadataType]. These correspond to the model input batch type,
model target batch type, and datum-level metadata batch type, respectively.

Methods
-------
__call__(datum: tuple[Sequence[ArrayLike], Sequence[ArrayLike], Sequence[DatumMetadataType]]) -> tuple[Sequence[ArrayLike], Sequence[ArrayLike], Sequence[DatumMetadataType]]
    Return a modified version of original data batch. A data batch is
    represented by a tuple of model input batch (as Sequence[ArrayLike] with
    elements of shape (C, H, W)), model target batch (as Sequence[ArrayLike]
    with elements of shape (Cl,)), and batch metadata (as
    Sequence[DatumMetadataType]), respectively.

Attributes
----------
metadata : AugmentationMetadata
    A typed dictionary containing at least an 'id' field of type str

Examples
--------
We can write an implementer of the augmentation class as either a function
or a class. The only requirement is that the object provide a __call__
method that takes objects at least as general as the types promised in the
protocol signature and return types at least as specific.

>>> import copy
>>> import numpy as np
>>> from typing import Any
>>> from collections.abc import Sequence
>>> from maite.protocols import ArrayLike, DatumMetadata, AugmentationMetadata
>>>
>>> class EnrichedDatumMetadata(DatumMetadata):
...     new_key: int  # add a field to those already in DatumMetadata
...
>>> class ImageAugmentation:
...     def __init__(self, aug_name: str):
...         self.metadata: AugmentationMetadata = {'id': aug_name}
...     def __call__(
...         self,
...         data_batch: tuple[Sequence[ArrayLike], Sequence[ArrayLike], Sequence[DatumMetadata]]
...     ) -> tuple[Sequence[np.ndarray], Sequence[np.ndarray], Sequence[EnrichedDatumMetadata]]:
...         inputs, targets, mds = data_batch
...         # We copy data passed into the constructor to avoid mutating original inputs
...         # By using np.ndarray constructor, the static type-checker will let us treat
...         # generic ArrayLike as a more narrow return type
...         inputs_aug = [copy.copy(np.array(input)) for input in inputs]
...         targets_aug = [copy.copy(np.array(target)) for target in targets]
...         # Modify inputs_aug, targets_aug, or mds_aug as needed
...         # In this example, we just add a new metadata field
...         mds_aug = []
...         for i, md in enumerate(mds):
...             mds_aug.append(EnrichedDatumMetadata(**md, new_key=i))
...         return inputs_aug, targets_aug, mds_aug

We can typehint an instance of the above class as an Augmentation in the
image_classification domain:

>>> from maite.protocols import image_classification as ic
>>> im_aug: ic.Augmentation = ImageAugmentation(aug_name='an_augmentation')
print(od.Augmentation.__doc__)
An augmentation protocol for the object detection subproblem.

An augmentation is expected to take a batch of data and return a modified
version of that batch. Implementers must provide a single method that takes
and returns a labeled data batch, where a labeled data batch is represented
by a tuple of types Sequence[ArrayLike], Sequence[ObjectDetectionTarget],
and Sequence[DatumMetadataType]. These correspond to the model input batch
type, model target batch type, and datum-level metadata batch type,
respectively.

Methods
-------
__call__(datum: tuple[Sequence[ArrayLike], Sequence[ObjectDetectionTarget], Sequence[DatumMetadataType]]) -> tuple[Sequence[ArrayLike], Sequence[ObjectDetectionTarget], Sequence[DatumMetadataType]]
    Return a modified version of original data batch. A data batch is
    represented by a tuple of model input batch (as Sequence[ArrayLike] with
    elements of shape (C, H, W)), model target batch (as
    Sequence[ObjectDetectionTarget]), and batch metadata (as
    Sequence[DatumMetadataType]), respectively.

Attributes
----------
metadata : AugmentationMetadata
    A typed dictionary containing at least an 'id' field of type str

Examples
--------
We can write an implementer of the augmentation class as either a function
or a class. The only requirement is that the object provide a __call__
method that takes objects at least as general as the types promised in the
protocol signature and return types at least as specific.

We create a dummy set of data and use it to create a class that implements
this lightweight protocol and augments the data:

>>> import numpy as np
>>> np.random.seed(1)
>>> import copy
>>> from dataclasses import dataclass
>>> from typing import Any
>>> from maite.protocols import AugmentationMetadata, object_detection as od

First, we specify parameters that will be used to create the dummy dataset.

>>> N_DATAPOINTS = 3  # datapoints in dataset
>>> N_CLASSES = 2  # possible classes that can be detected
>>> C = 3  # number of color channels
>>> H = 10  # img height
>>> W = 10  # img width

Next, we create the input data to be used by the Augmentation. In this
example, we create the following batch of object detection data:

- xb is the input batch data. Our batch will include N_DATAPOINTS samples.
  Note that we initialize all of the data to zeros in this example to
  demonstrate the augmentations better.
- yb is the object detection target data, which in this example represents
  zero object detections for each input datum (by having empty bounding
  boxes and class labels and scores).
- mdb is the associated metadata for each input datum.

>>> @dataclass
... class MyObjectDetectionTarget:
...     boxes: np.ndarray
...     labels: np.ndarray
...     scores: np.ndarray
>>> xb: od.InputBatchType = list(np.zeros((N_DATAPOINTS, C, H, W)))
>>> yb: od.TargetBatchType = list(
...     MyObjectDetectionTarget(boxes=np.empty(0), labels=np.empty(0), scores=np.empty(0))
...     for _ in range(N_DATAPOINTS)
... )
>>> mdb: od.DatumMetadataBatchType = list({"id": i} for i in range(N_DATAPOINTS))
>>> # Display the first datum in batch, first color channel, and only first 5 rows and cols
>>> np.array(xb[0])[0][:5, :5]
array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]])

Now we create the Augmentation, which will apply random noise (rounded to 3
decimal places) to the input data using the numpy.random.random and np.round
functions.

>>> np_noise = lambda shape: np.round(np.random.random(shape), 3)
>>> class ImageAugmentation:
...     def __init__(self, aug_func: Any, metadata: AugmentationMetadata):
...         self.aug_func = aug_func
...         self.metadata = metadata
...     def __call__(
...         self,
...         batch: tuple[od.InputBatchType, od.TargetBatchType, od.DatumMetadataBatchType],
...     ) -> tuple[od.InputBatchType, od.TargetBatchType, od.DatumMetadataBatchType]:
...         xb, yb, mdb = batch
...         # Copy data passed into the constructor to avoid mutating original inputs
...         xb_aug = [copy.copy(input) for input in xb]
...         # Add random noise to the input batch data, xb
...         # (Note that all batch data dimensions (shapes) are the same in this example)
...         shape = np.array(xb[0]).shape
...         xb_aug = [x + self.aug_func(shape) for x in xb]
...         # Note that this example augmentation only affects inputs--not targets
...         return xb_aug, yb, mdb

We can typehint an instance of the above class as an Augmentation in the
object detection domain:

>>> noise: od.Augmentation = ImageAugmentation(np_noise, metadata={"id": "np_rand_noise"})

Now we can apply the noise augmentation to our 3-tuple batch of data. Recall
that our data was initialized to all zeros, so any non-zero values in the
augmented data are a result of the augmentation.

>>> xb_aug, yb_aug, mdb_aug = noise((xb, yb, mdb))
>>> # Display the first datum in batch, first color channel, and only first 5 rows and cols
>>> np.array(xb_aug[0])[0][:5, :5]
array([[0.417, 0.72 , 0.   , 0.302, 0.147],
       [0.419, 0.685, 0.204, 0.878, 0.027],
       [0.801, 0.968, 0.313, 0.692, 0.876],
       [0.098, 0.421, 0.958, 0.533, 0.692],
       [0.989, 0.748, 0.28 , 0.789, 0.103]])
6 Metrics

The Metric protocol is inspired by the design of existing libraries like
Torchmetrics and Torcheval. The update method operates on batches of
predictions and truth labels by either caching them for later computation
of the metric (via compute) or updating sufficient statistics in an online
fashion.
print(ic.Metric.__doc__)
A metric protocol for the image classification ML subproblem.
A metric in this sense is expected to measure the level of agreement between model
predictions and ground-truth labels.
Methods
-------
update(preds: Sequence[ArrayLike], targets: Sequence[ArrayLike]) -> None
Add predictions and targets to metric's cache for later calculation. Both
preds and targets are expected to be sequences with elements of shape (Cl,).
compute() -> dict[str, Any]
Compute metric value(s) for currently cached predictions and targets, returned as
a dictionary.
reset() -> None
Clear contents of current metric's cache of predictions and targets.
Attributes
----------
metadata : MetricMetadata
A typed dictionary containing at least an 'id' field of type str
Examples
--------
Create a basic accuracy metric and test it on a small example dataset:
>>> from typing import Any, Sequence
>>> import numpy as np
>>> from maite.protocols import ArrayLike, MetricMetadata
>>> from maite.protocols import image_classification as ic
>>> class MyAccuracy:
... metadata: MetricMetadata = {'id': 'Example Multiclass Accuracy'}
...
... def __init__(self):
... self._total = 0
... self._correct = 0
...
... def reset(self) -> None:
... self._total = 0
... self._correct = 0
...
... def update(self, preds: Sequence[ArrayLike], targets: Sequence[ArrayLike]) -> None:
... model_probs = [np.array(r) for r in preds]
... true_onehot = [np.array(r) for r in targets]
...
... # Stack into single array, convert to class indices
... model_classes = np.vstack(model_probs).argmax(axis=1)
... truth_classes = np.vstack(true_onehot).argmax(axis=1)
...
... # Compare classes and update running counts
... same = (model_classes == truth_classes)
... self._total += len(same)
... self._correct += same.sum()
...
... def compute(self) -> dict[str, Any]:
... if self._total > 0:
... return {"accuracy": self._correct / self._total}
... else:
... raise Exception("No batches processed yet.")
Instantiate this class and typehint it as an image_classification.Metric.
By using typehinting, we permit a static typechecker to verify protocol compliance.
>>> accuracy: ic.Metric = MyAccuracy()
To use the metric, call update() for each batch of predictions and truth values, and call compute() to calculate the final metric values.
>>> # batch 1
>>> model_probs = [np.array([0.8, 0.1, 0.0, 0.1]), np.array([0.1, 0.2, 0.6, 0.1])] # predicted classes: 0, 2
>>> true_onehot = [np.array([1.0, 0.0, 0.0, 0.0]), np.array([0.0, 1.0, 0.0, 0.0])] # true classes: 0, 1
>>> accuracy.update(model_probs, true_onehot)
>>> print(accuracy.compute())
{'accuracy': 0.5}
>>>
>>> # batch 2
>>> model_probs = [np.array([0.1, 0.1, 0.7, 0.1]), np.array([0.0, 0.1, 0.0, 0.9])] # predicted classes: 2, 3
>>> true_onehot = [np.array([0.0, 0.0, 1.0, 0.0]), np.array([0.0, 0.0, 0.0, 1.0])] # true classes: 2, 3
>>> accuracy.update(model_probs, true_onehot)
>>>
>>> print(accuracy.compute())
{'accuracy': 0.75}
print(od.Metric.__doc__)
A metric protocol for the object detection ML subproblem.

A metric in this sense is expected to measure the level of agreement between
model predictions and ground-truth labels.

Methods
-------
update(preds: Sequence[ObjectDetectionTarget], targets: Sequence[ObjectDetectionTarget]) -> None
    Add predictions and targets to metric's cache for later calculation.
compute() -> dict[str, Any]
    Compute metric value(s) for currently cached predictions and targets,
    returned as a dictionary.
reset() -> None
    Clear contents of current metric's cache of predictions and targets.

Attributes
----------
metadata : MetricMetadata
    A typed dictionary containing at least an 'id' field of type str

Examples
--------
Below, we write and test a class that implements the Metric protocol for
object detection. For simplicity, the metric will compute the intersection
over union (IoU) averaged over all predicted and associated ground truth
boxes for a single image, and then take the mean over the per-image means.

Note that when writing a Metric implementer, return types may be narrower
than the return types promised by the protocol, but the argument types must
be at least as general as the argument types promised by the protocol.

>>> from dataclasses import dataclass
>>> from maite.protocols import ArrayLike, MetricMetadata
>>> from typing import Any, Sequence
>>> import maite.protocols.object_detection as od
>>> import numpy as np
>>> class MyIoUMetric:
...
...     def __init__(self, id: str):
...         self.pred_boxes = []  # elements correspond to predicted boxes in single image
...         self.target_boxes = []  # elements correspond to ground truth boxes in single image
...         # Store provided id for this metric instance
...         self.metadata = MetricMetadata(id=id)
...
...     def reset(self) -> None:
...         self.pred_boxes = []
...         self.target_boxes = []
...
...     def update(
...         self,
...         pred_batch: Sequence[od.ObjectDetectionTarget],
...         target_batch: Sequence[od.ObjectDetectionTarget],
...     ) -> None:
...         self.pred_boxes.extend(pred_batch)
...         self.target_boxes.extend(target_batch)
...
...     @staticmethod
...     def iou_vec(boxes_a: ArrayLike, boxes_b: ArrayLike) -> np.ndarray:
...         # Break up points into separate columns
...         x0a, y0a, x1a, y1a = np.split(boxes_a, 4, axis=1)
...         x0b, y0b, x1b, y1b = np.split(boxes_b, 4, axis=1)
...         # Calculate intersections
...         xi_0, yi_0 = np.split(
...             np.maximum(np.append(x0a, y0a, axis=1), np.append(x0b, y0b, axis=1)),
...             2,
...             axis=1,
...         )
...         xi_1, yi_1 = np.split(
...             np.minimum(np.append(x1a, y1a, axis=1), np.append(x1b, y1b, axis=1)),
...             2,
...             axis=1,
...         )
...         ints: np.ndarray = np.maximum(0, xi_1 - xi_0) * np.maximum(0, yi_1 - yi_0)
...         # Calculate unions (as sum of areas minus their intersection)
...         unions: np.ndarray = (
...             (x1a - x0a) * (y1a - y0a)
...             + (x1b - x0b) * (y1b - y0b)
...             - (xi_1 - xi_0) * (yi_1 - yi_0)
...         )
...         return ints / unions
...
...     def compute(self) -> dict[str, Any]:
...         mean_iou_by_img: list[float] = []
...         for pred_box, tgt_box in zip(self.pred_boxes, self.target_boxes):
...             single_img_ious = self.iou_vec(pred_box.boxes, tgt_box.boxes)
...             mean_iou_by_img.append(float(np.mean(single_img_ious)))
...         return {"mean_iou": np.mean(np.array(mean_iou_by_img))}
...

Now we can instantiate our IoU Metric class:

>>> iou_metric: od.Metric = MyIoUMetric(id="IoUMetric")

To use the metric, we populate two lists that encode predicted object
detections and ground truth object detections for a single image.
(Ordinarily, predictions would be produced by a model.)

>>> prediction_boxes: list[tuple[int, int, int, int]] = [
...     (1, 1, 12, 12),
...     (100, 100, 120, 120),
...     (180, 180, 270, 270),
... ]
>>> target_boxes: list[tuple[int, int, int, int]] = [
...     (1, 1, 10, 10),
...     (100, 100, 120, 120),
...     (200, 200, 300, 300),
... ]

The MAITE Metric protocol requires the pred_batch and target_batch arguments
to the update method to be assignable to Sequence[ObjectDetectionTarget]
(where ObjectDetectionTarget encodes detections in a single image). We define
an implementation of ObjectDetectionTarget and use it to pass ground truth
and predicted detections.

>>> @dataclass
... class ObjectDetectionTargetImpl:
...     boxes: np.ndarray
...     labels: np.ndarray
...     scores: np.ndarray
>>> num_boxes = len(target_boxes)
>>> fake_labels = np.random.randint(0, 9, num_boxes)
>>> fake_scores = np.zeros(num_boxes)
>>> pred_batch = [
...     ObjectDetectionTargetImpl(
...         boxes=np.array(prediction_boxes), labels=fake_labels, scores=fake_scores
...     )
... ]
>>> target_batch: Sequence[ObjectDetectionTargetImpl] = [
...     ObjectDetectionTargetImpl(
...         boxes=np.array(target_boxes), labels=fake_labels, scores=fake_scores
...     )
... ]

Finally, we call update using this one-element batch, compute the metric
value, and print it:

>>> iou_metric.update(pred_batch, target_batch)
>>> print(iou_metric.compute())
{'mean_iou': 0.6802112029384757}
7 Workflows

MAITE provides high-level utilities for common workflows such as evaluate
and predict. They can be called with either Datasets or DataLoaders, and
with an optional Augmentation.

The evaluate function can optionally return the model predictions and
(potentially augmented) data batches used during inference.

The predict function returns the model predictions and (potentially
augmented) data batches used during inference, essentially calling
evaluate with a dummy metric.
from maite.workflows import evaluate, predict

print(evaluate.__doc__)
Evaluate a model's performance on data according to some metric with optional augmentation.
Some data source (either a dataloader or a dataset) must be provided
or an InvalidArgument exception is raised.
Parameters
----------
model : SomeModel
Maite Model object.
metric : SomeMetric | None, (default=None)
Compatible maite Metric.
dataloader : SomeDataloader | None, (default=None)
Compatible maite dataloader.
dataset : SomeDataset | None, (default=None)
Compatible maite dataset.
batch_size : int, (default=1)
Batch size for use with dataset (ignored if dataset=None).
augmentation : SomeAugmentation | None, (default=None)
Compatible maite augmentation.
return_augmented_data : bool, (default=False)
Set to True to return post-augmentation data as a function output.
return_preds : bool, (default=False)
Set to True to return raw predictions as a function output.
Returns
-------
tuple[dict[str, Any], Sequence[TargetType], Sequence[tuple[InputBatchType, TargetBatchType, DatumMetadataBatchType]]]
Tuple of returned metric value, sequence of model predictions, and
sequence of data batch tuples fed to the model during inference. The actual
types represented by InputBatchType, TargetBatchType, and DatumMetadataBatchType will vary
by the domain of the components provided as input arguments (e.g., image
classification or object detection).
Note that the second and third return arguments will be empty if
return_preds is False or return_augmented_data is False, respectively.
print(predict.__doc__)
Make predictions for a given model & data source with optional augmentation.
Some data source (either a dataloader or a dataset) must be provided
or an InvalidArgument exception is raised.
Parameters
----------
model : SomeModel
Maite Model object.
dataloader : SomeDataloader | None, (default=None)
Compatible maite dataloader.
dataset : SomeDataset | None, (default=None)
Compatible maite dataset.
batch_size : int, (default=1)
Batch size for use with dataset (ignored if dataset=None).
augmentation : SomeAugmentation | None, (default=None)
Compatible maite augmentation.
Returns
-------
tuple[Sequence[SomeTargetBatchType], Sequence[tuple[SomeInputBatchType, SomeTargetBatchType, SomeMetadataBatchType]]]
A tuple of the predictions (as a sequence of batches) and a sequence
of tuples containing the information associated with each batch.
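Putting it together, the following sketch runs predict end-to-end over toy components (the components are our own minimal implementations; the import path assumes the maite.workflows module shown above):

import numpy as np
from typing import Sequence
from maite.workflows import predict
import maite.protocols.image_classification as ic
from maite.protocols import ArrayLike

class TinyDataset:
    metadata = {"id": "tiny_dataset"}
    def __len__(self) -> int:
        return 4
    def __getitem__(self, i: int):
        # (input, one-hot target over 2 classes, datum metadata)
        return np.random.rand(3, 8, 8), np.eye(2)[i % 2], {"id": i}

class ConstantModel:
    metadata = {"id": "constant_model"}
    def __call__(self, batch: Sequence[ArrayLike]) -> Sequence[np.ndarray]:
        return [np.array([1.0, 0.0]) for _ in batch]  # always predict class 0

# typehints (optional) let a static checker verify protocol compliance
dataset: ic.Dataset = TinyDataset()
model: ic.Model = ConstantModel()

preds, batches = predict(model=model, dataset=dataset, batch_size=2)
print(len(preds))  # 2 batches of predictions (4 datapoints, batch_size=2)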