image_classification.Metric
class maite.protocols.image_classification.Metric(*args, **kwargs)
A metric protocol for the image classification AI task.
A metric in this sense is expected to measure the level of agreement between model predictions and ground-truth labels.
Attributes:
    metadata : MetricMetadata
        A typed dictionary containing at least an 'id' field of type str.
Methods
update(pred_batch: Sequence[ArrayLike], target_batch: Sequence[ArrayLike], metadata_batch: Sequence[DatumMetadataType]) -> None
    Add predictions and targets (and metadata, if applicable) to the metric's cache for later calculation. Both predictions and targets are expected to be sequences with elements of shape (Cl,).
compute() -> dict[str, Any]
    Compute metric value(s) for currently cached predictions and targets, returned as a dictionary.
reset() -> None
    Clear the contents of the metric's cache of predictions and targets.
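The (Cl,) shape convention means each element of pred_batch and target_batch is a length-Cl vector over the class labels: a probability (or score) vector for predictions and a one-hot vector for targets. A minimal numpy sketch (independent of MAITE) of how such vectors reduce to the class indices a metric compares:

```python
import numpy as np

# One datum with Cl = 4 classes: the prediction is a score vector,
# the target is a one-hot vector -- both of shape (4,).
pred = np.array([0.1, 0.2, 0.6, 0.1])
target = np.array([0.0, 0.0, 1.0, 0.0])

# argmax reduces each (Cl,)-shaped vector to a single class index
pred_class = pred.argmax()    # 2
true_class = target.argmax()  # 2

print(pred_class == true_class)  # True -- this datum counts as correct
```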
Examples
Create a basic accuracy metric and test it on a small example dataset:
>>> from typing import Any, Sequence
>>> import numpy as np
>>> from maite.protocols import ArrayLike, MetricMetadata
>>> from maite.protocols import image_classification as ic
>>> class MyAccuracy:
...     metadata: MetricMetadata = {"id": "Example Multiclass Accuracy"}
...
...     def __init__(self):
...         self._total = 0
...         self._correct = 0
...
...     def reset(self) -> None:
...         self._total = 0
...         self._correct = 0
...
...     def update(
...         self,
...         pred_batch: Sequence[ArrayLike],
...         target_batch: Sequence[ArrayLike],
...         metadata_batch: Sequence[ic.DatumMetadataType],
...     ) -> None:
...         model_preds = [np.array(r) for r in pred_batch]
...         true_onehot = [np.array(r) for r in target_batch]
...
...         # Stack into single array, convert to class indices
...         model_classes = np.vstack(model_preds).argmax(axis=1)
...         truth_classes = np.vstack(true_onehot).argmax(axis=1)
...
...         # Compare classes and update running counts
...         same = model_classes == truth_classes
...         self._total += len(same)
...         self._correct += same.sum().item()
...
...     def compute(self) -> dict[str, Any]:
...         if self._total > 0:
...             return {"accuracy": self._correct / self._total}
...         else:
...             raise Exception("No batches processed yet.")
Instantiate this class and type-hint it as an image_classification.Metric. The type hint permits a static type checker to verify protocol compliance.
>>> accuracy: ic.Metric = MyAccuracy()
To use the metric, call update() for each batch of predictions and truth values, then call compute() to calculate the final metric value(s).
>>> # batch 1
>>> model_preds = [
...     np.array([0.8, 0.1, 0.0, 0.1]),
...     np.array([0.1, 0.2, 0.6, 0.1]),
... ]  # predicted classes: 0, 2
>>> true_onehot = [
...     np.array([1.0, 0.0, 0.0, 0.0]),
...     np.array([0.0, 1.0, 0.0, 0.0]),
... ]  # true classes: 0, 1
>>> metadatas: list[ic.DatumMetadataType] = [{"id": 1}, {"id": 2}]
>>> accuracy.update(model_preds, true_onehot, metadatas)
>>> print(accuracy.compute())
{'accuracy': 0.5}
>>>
>>> # batch 2
>>> model_preds = [
...     np.array([0.1, 0.1, 0.7, 0.1]),
...     np.array([0.0, 0.1, 0.0, 0.9]),
... ]  # predicted classes: 2, 3
>>> true_onehot = [
...     np.array([0.0, 0.0, 1.0, 0.0]),
...     np.array([0.0, 0.0, 0.0, 1.0]),
... ]  # true classes: 2, 3
>>> metadatas: list[ic.DatumMetadataType] = [{"id": 3}, {"id": 4}]
>>> accuracy.update(model_preds, true_onehot, metadatas)
>>>
>>> print(accuracy.compute())
{'accuracy': 0.75}
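Because update() accumulates into a cache, reset() lets one metric object be reused across separate evaluations. A minimal sketch of that lifecycle, using a plain counting metric written for this illustration (not part of MAITE):

```python
import numpy as np
from typing import Any


class TinyAccuracy:
    """Minimal accuracy metric illustrating the update/compute/reset cycle."""

    metadata = {"id": "tiny-accuracy"}

    def __init__(self) -> None:
        self._total = 0
        self._correct = 0

    def reset(self) -> None:
        # Clear the cached counts so a fresh evaluation can begin
        self._total = 0
        self._correct = 0

    def update(self, pred_batch, target_batch, metadata_batch) -> None:
        preds = np.vstack([np.array(p) for p in pred_batch]).argmax(axis=1)
        truth = np.vstack([np.array(t) for t in target_batch]).argmax(axis=1)
        self._total += len(preds)
        self._correct += int((preds == truth).sum())

    def compute(self) -> dict[str, Any]:
        return {"accuracy": self._correct / self._total}


m = TinyAccuracy()

# First evaluation: one correct datum
m.update([np.array([0.9, 0.1])], [np.array([1.0, 0.0])], [{"id": 0}])
print(m.compute())  # {'accuracy': 1.0}

# reset() clears the cache; the second evaluation starts from zero counts
m.reset()
m.update([np.array([0.2, 0.8])], [np.array([1.0, 0.0])], [{"id": 1}])
print(m.compute())  # {'accuracy': 0.0}
```

Without the reset() call, the second compute() would mix counts from both evaluations (1 correct out of 2, i.e. 0.5) rather than scoring the second batch alone.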