maite.protocols.image_classification.Metric

class maite.protocols.image_classification.Metric(*args, **kwargs)

A metric protocol for the image classification ML subproblem.

A metric in this sense is expected to measure the level of agreement between model predictions and ground-truth labels.

Examples

Create a basic accuracy metric and test it on a small example dataset:

>>> from typing import Any, Sequence
>>> import numpy as np
>>> from maite.protocols import ArrayLike, MetricMetadata
>>> from maite.protocols import image_classification as ic
>>> class MyAccuracy:
...    metadata: MetricMetadata = {'id': 'Example Multiclass Accuracy'}
...
...    def __init__(self):
...        self._total = 0
...        self._correct = 0
...
...    def reset(self) -> None:
...        self._total = 0
...        self._correct = 0
...
...    def update(self, preds: Sequence[ArrayLike], targets: Sequence[ArrayLike]) -> None:
...        model_probs = [np.array(r) for r in preds]
...        true_onehot = [np.array(r) for r in targets]
...
...        # Stack into single array, convert to class indices
...        model_classes = np.vstack(model_probs).argmax(axis=1)
...        truth_classes = np.vstack(true_onehot).argmax(axis=1)
...
...        # Compare classes and update running counts
...        same = (model_classes == truth_classes)
...        self._total += len(same)
...        self._correct += int(same.sum())  # plain int so compute() returns a Python float
...
...    def compute(self) -> dict[str, Any]:
...        if self._total > 0:
...            return {"accuracy": self._correct / self._total}
...        else:
...            raise Exception("No batches processed yet.")

Instantiate this class and type hint it as an image_classification.Metric. The type hint permits a static type checker to verify protocol compliance.

>>> accuracy: ic.Metric = MyAccuracy()

To use the metric, call update() for each batch of predictions and truth values, then call compute() to calculate the final metric values.

>>> # batch 1
>>> model_probs = [np.array([0.8, 0.1, 0.0, 0.1]), np.array([0.1, 0.2, 0.6, 0.1])] # predicted classes: 0, 2
>>> true_onehot = [np.array([1.0, 0.0, 0.0, 0.0]), np.array([0.0, 1.0, 0.0, 0.0])] # true classes: 0, 1
>>> accuracy.update(model_probs, true_onehot)
>>> print(accuracy.compute())
{'accuracy': 0.5}
>>>
>>> # batch 2
>>> model_probs = [np.array([0.1, 0.1, 0.7, 0.1]), np.array([0.0, 0.1, 0.0, 0.9])] # predicted classes: 2, 3
>>> true_onehot = [np.array([0.0, 0.0, 1.0, 0.0]), np.array([0.0, 0.0, 0.0, 1.0])] # true classes: 2, 3
>>> accuracy.update(model_probs, true_onehot)
>>>
>>> print(accuracy.compute())
{'accuracy': 0.75}
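
Calling reset() clears the metric's cached counts so the same instance can be reused for a new evaluation run; a minimal sketch continuing the example above (re-submitting only batch 2):

>>> accuracy.reset()
>>> accuracy.update(model_probs, true_onehot)  # model_probs/true_onehot still hold batch 2
>>> print(accuracy.compute())
{'accuracy': 1.0}
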
Attributes:

metadata: MetricMetadata

A typed dictionary containing at least an ‘id’ field of type str.
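
Continuing the example above, metadata is an ordinary dictionary, so the required ‘id’ field can be read directly:

>>> accuracy.metadata['id']
'Example Multiclass Accuracy'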

Methods

update(preds: Sequence[ArrayLike], targets: Sequence[ArrayLike]) -> None

Add predictions and targets to metric’s cache for later calculation. Both preds and targets are expected to be sequences with elements of shape (Cl,).

compute() -> dict[str, Any]

Compute metric value(s) for currently cached predictions and targets, returned as a dictionary.

reset() -> None

Clear contents of current metric’s cache of predictions and targets.