maite.protocols.image_classification.Metric
class maite.protocols.image_classification.Metric(*args, **kwargs)
A metric protocol for the image classification ML subproblem.
A metric in this sense is expected to measure the level of agreement between model predictions and ground-truth labels.
Examples
Create a basic accuracy metric and test it on a small example dataset:
>>> from typing import Any, Sequence
>>> import numpy as np
>>> from maite.protocols import ArrayLike, MetricMetadata
>>> import maite.protocols.image_classification as ic
>>>
>>> class MyAccuracy:
...     metadata: MetricMetadata = {'id': 'Example Multiclass Accuracy'}
...
...     def __init__(self):
...         self._total = 0
...         self._correct = 0
...
...     def reset(self) -> None:
...         self._total = 0
...         self._correct = 0
...
...     def update(self, preds: Sequence[ArrayLike], targets: Sequence[ArrayLike]) -> None:
...         model_probs = [np.array(r) for r in preds]
...         true_onehot = [np.array(r) for r in targets]
...
...         # Stack into single array, convert to class indices
...         model_classes = np.vstack(model_probs).argmax(axis=1)
...         truth_classes = np.vstack(true_onehot).argmax(axis=1)
...
...         # Compare classes and update running counts
...         same = (model_classes == truth_classes)
...         self._total += len(same)
...         self._correct += same.sum()
...
...     def compute(self) -> dict[str, Any]:
...         if self._total > 0:
...             return {"accuracy": self._correct / self._total}
...         else:
...             raise Exception("No batches processed yet.")
Instantiate this class and type hint it as an image_classification.Metric. The type hint permits a static type checker to verify protocol compliance.
>>> accuracy: ic.Metric = MyAccuracy()
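For instance, a hypothetical class that omits the required methods would still instantiate at runtime, but a static type checker such as mypy or pyright would flag the annotated assignment as a protocol-compliance error:

>>> class NotAMetric:  # hypothetical: missing update, compute, and reset
...     metadata: MetricMetadata = {'id': 'Incomplete Metric'}
>>> bad: ic.Metric = NotAMetric()  # static type checker reports an error here; no runtime error is raised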
To use the metric, call update() for each batch of predictions and truth values, then call compute() to calculate the final metric values.
>>> # batch 1
>>> model_probs = [np.array([0.8, 0.1, 0.0, 0.1]), np.array([0.1, 0.2, 0.6, 0.1])]  # predicted classes: 0, 2
>>> true_onehot = [np.array([1.0, 0.0, 0.0, 0.0]), np.array([0.0, 1.0, 0.0, 0.0])]  # true classes: 0, 1
>>> accuracy.update(model_probs, true_onehot)
>>> print(accuracy.compute())
{'accuracy': 0.5}
>>>
>>> # batch 2
>>> model_probs = [np.array([0.1, 0.1, 0.7, 0.1]), np.array([0.0, 0.1, 0.0, 0.9])]  # predicted classes: 2, 3
>>> true_onehot = [np.array([0.0, 0.0, 1.0, 0.0]), np.array([0.0, 0.0, 0.0, 1.0])]  # true classes: 2, 3
>>> accuracy.update(model_probs, true_onehot)
>>>
>>> print(accuracy.compute())
{'accuracy': 0.75}

Note that compute() reflects all cached batches, not just the most recent one: after batch 2, three of the four total predictions are correct, giving an accuracy of 0.75.
Attributes:
    metadata : MetricMetadata
        A typed dictionary containing at least an 'id' field of type str.
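Continuing the example above, the metadata of the accuracy metric can be read like any dictionary:

>>> accuracy.metadata['id']
'Example Multiclass Accuracy'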
Methods
update(preds: Sequence[ArrayLike], targets: Sequence[ArrayLike]) -> None
Add predictions and targets to metric's cache for later calculation. Both preds and targets are expected to be sequences with elements of shape (Cl,), where Cl is the number of classes.

compute() -> dict[str, Any]
Compute metric value(s) for currently cached predictions and targets, returned as a dictionary.
reset() -> None
Clear contents of current metric’s cache of predictions and targets.
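Continuing the example, reset() clears the metric's cached state; with MyAccuracy as defined above, calling compute() before any new batches are cached raises an exception:

>>> accuracy.reset()
>>> accuracy.compute()
Traceback (most recent call last):
    ...
Exception: No batches processed yet.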