image_classification.Metric

class maite.protocols.image_classification.Metric(*args, **kwargs)

A metric protocol for the image classification AI task.

A metric in this sense is expected to measure the level of agreement between model predictions and ground-truth labels.

Attributes

metadata : MetricMetadata

A typed dictionary containing at least an ‘id’ field of type str.

Methods

update(pred_batch: Sequence[ArrayLike], target_batch: Sequence[ArrayLike], metadata_batch: Sequence[DatumMetadataType]) -> None

Add predictions and targets (and metadata, if applicable) to the metric’s cache for later calculation. Both predictions and targets are expected to be sequences with elements of shape (Cl,), where Cl is the number of classes.

compute() -> dict[str, Any]

Compute metric value(s) for currently cached predictions and targets, returned as a dictionary.

reset() -> None

Clear the contents of the metric’s cache of predictions and targets.
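
Taken together, these methods support a simple accumulate-then-compute workflow, sketched below. The evaluate helper and its batches argument are illustrative only, not part of the protocol; they assume an iterable yielding (pred_batch, target_batch, metadata_batch) triples.

>>> def evaluate(metric, batches):
...     metric.reset()  # start from an empty cache
...     for pred_batch, target_batch, metadata_batch in batches:
...         metric.update(pred_batch, target_batch, metadata_batch)
...     return metric.compute()  # aggregate over all cached batches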

Examples

Create a basic accuracy metric and test it on a small example dataset:

>>> from typing import Any, Sequence
>>> import numpy as np
>>> from maite.protocols import ArrayLike, MetricMetadata
>>> from maite.protocols import image_classification as ic
>>> class MyAccuracy:
...     metadata: MetricMetadata = {"id": "Example Multiclass Accuracy"}
...
...     def __init__(self):
...         self._total = 0
...         self._correct = 0
...
...     def reset(self) -> None:
...         self._total = 0
...         self._correct = 0
...
...     def update(
...         self,
...         pred_batch: Sequence[ArrayLike],
...         target_batch: Sequence[ArrayLike],
...         metadata_batch: Sequence[ic.DatumMetadataType],
...     ) -> None:
...         model_preds = [np.array(r) for r in pred_batch]
...         true_onehot = [np.array(r) for r in target_batch]
...
...         # Stack into single array, convert to class indices
...         model_classes = np.vstack(model_preds).argmax(axis=1)
...         truth_classes = np.vstack(true_onehot).argmax(axis=1)
...
...         # Compare classes and update running counts
...         same = model_classes == truth_classes
...         self._total += len(same)
...         self._correct += same.sum().item()
...
...     def compute(self) -> dict[str, Any]:
...         if self._total > 0:
...             return {"accuracy": self._correct / self._total}
...         else:
...             raise RuntimeError("No batches processed yet.")

Instantiate this class and type hint it as an image_classification.Metric. The type hint permits a static type checker to verify protocol compliance.

>>> accuracy: ic.Metric = MyAccuracy()

To use the metric, call update() for each batch of predictions and truth values, then call compute() to calculate the final metric value(s).

>>> # batch 1
>>> model_preds = [
...     np.array([0.8, 0.1, 0.0, 0.1]),
...     np.array([0.1, 0.2, 0.6, 0.1]),
... ]  # predicted classes: 0, 2
>>> true_onehot = [
...     np.array([1.0, 0.0, 0.0, 0.0]),
...     np.array([0.0, 1.0, 0.0, 0.0]),
... ]  # true classes: 0, 1
>>> metadatas: list[ic.DatumMetadataType] = [{"id": 1}, {"id": 2}]
>>> accuracy.update(model_preds, true_onehot, metadatas)
>>> print(accuracy.compute())
{'accuracy': 0.5}
>>>
>>> # batch 2
>>> model_preds = [
...     np.array([0.1, 0.1, 0.7, 0.1]),
...     np.array([0.0, 0.1, 0.0, 0.9]),
... ]  # predicted classes: 2, 3
>>> true_onehot = [
...     np.array([0.0, 0.0, 1.0, 0.0]),
...     np.array([0.0, 0.0, 0.0, 1.0]),
... ]  # true classes: 2, 3
>>> metadatas: list[ic.DatumMetadataType] = [{"id": 3}, {"id": 4}]
>>> accuracy.update(model_preds, true_onehot, metadatas)
>>>
>>> print(accuracy.compute())
{'accuracy': 0.75}
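
Calling reset() clears the metric’s cache so a fresh evaluation can begin. Re-running update() on only the second batch (whose variables are still in scope from above) then yields perfect accuracy:

>>> accuracy.reset()
>>> accuracy.update(model_preds, true_onehot, metadatas)
>>> print(accuracy.compute())
{'accuracy': 1.0}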