metrics.torchmetrics.TMClassificationMetric

class maite.interop.metrics.torchmetrics.TMClassificationMetric(metric, output_key=None, output_transform=None, device=None, dtype=None, metadata=None)

MAITE-compliant Wrapper for TorchMetrics Classification Metrics.

TMClassificationMetric is a wrapper around torchmetrics.classification metrics adhering to MAITE’s image_classification.Metric protocol.

Notes

Only Multiclass metrics are currently supported. Please refer to the torchmetrics documentation for more information: https://lightning.ai/docs/torchmetrics/stable/.

Supported metrics:

(Unsupported metrics either require preds and targets of differing shapes or require additional parameters in their update() function signature.)

Examples

>>> import torch
>>> import torchmetrics.classification
>>> from typing_extensions import Sequence
>>> from maite.protocols import image_classification as ic, MetricMetadata
>>> from maite.interop.metrics.torchmetrics import TMClassificationMetric
>>>
>>> preds: Sequence[ic.TargetType] = [torch.tensor([0.1, 0.8, 0.1]), torch.tensor([0.6, 0.2, 0.2]), torch.tensor([0.4, 0.3, 0.3])]
>>> target: Sequence[ic.TargetType] = [torch.tensor([0, 1, 0]), torch.tensor([1, 0, 0]), torch.tensor([0, 0, 1])]
>>>
>>> # Create native TorchMetrics metric
>>> classification_metric = torchmetrics.classification.MulticlassAccuracy(num_classes=3)
>>>
>>> # Add additional field to base MetricMetadata
>>> class MyMetricMetadata(MetricMetadata):
...     num_classes: int
>>> metadata: MyMetricMetadata = {"id": "Multiclass Accuracy", "num_classes": 3}
>>>
>>> # Wrap metric and apply to sample data
>>> wrapped_classification_metric: ic.Metric = TMClassificationMetric(
...     classification_metric, metadata=metadata
... )
>>> wrapped_classification_metric.update(preds, target)
>>> result = wrapped_classification_metric.compute()
>>> result  
{'MulticlassAccuracy': tensor(0.6667)}
>>> print(f"{result['MulticlassAccuracy'].item():0.3f}")  # consistent formatting for doctest
0.667
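The value 0.667 can be checked by hand. The sketch below recomputes it in plain Python, without torch or MAITE. It rests on two assumptions that the example above does not state: that the one-hot targets are reduced to class indices by argmax, and that MulticlassAccuracy uses its default macro averaging (the mean of per-class recall). The helper name argmax is illustrative only.

```python
# Plain-Python recomputation of the example result; not MAITE or torchmetrics code.
preds = [[0.1, 0.8, 0.1], [0.6, 0.2, 0.2], [0.4, 0.3, 0.3]]
target = [[0, 1, 0], [1, 0, 0], [0, 0, 1]]
num_classes = 3


def argmax(values):
    """Return the index of the largest value (the first index wins on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


# Assumption: scores and one-hot targets are both reduced to class indices.
pred_labels = [argmax(p) for p in preds]    # [1, 0, 0]
true_labels = [argmax(t) for t in target]   # [1, 0, 2]

# Per-class recall: correct predictions for a class / samples of that class.
per_class = []
for c in range(num_classes):
    indices = [i for i, t in enumerate(true_labels) if t == c]
    if indices:  # skip classes that never occur in the targets
        correct = sum(pred_labels[i] == c for i in indices)
        per_class.append(correct / len(indices))

# Macro average: unweighted mean over classes, here (1.0 + 1.0 + 0.0) / 3.
macro_accuracy = sum(per_class) / len(per_class)
print(f"{macro_accuracy:0.3f}")  # 0.667
```

Two of the three samples are classified correctly, and because each class occurs exactly once, micro and macro averaging give the same value here.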
Parameters:
metric : torchmetrics.Metric

Metric being wrapped. The metric must be a torchmetrics.classification metric.

output_key : str, optional

Outermost key of the dictionary returned by TMClassificationMetric.compute(). If neither output_key nor output_transform is provided and the metric result is not already a dictionary with str keys, then the name of the torchmetrics.Metric is used as the default outermost key. Note: At most one of output_key and output_transform may be provided.

output_transform : Callable[[torch.Tensor], dict[str, Any]], optional

Function that takes the output of torchmetrics.Metric.compute() as input and returns a modified version of it to be returned by TMClassificationMetric.compute(). Note: At most one of output_key and output_transform may be provided.

device : Any, optional

Torch device type on which a torch.Tensor is or will be allocated. If none is passed, then the device will be inferred by torch.

dtype : torch.dtype, optional

Torch data type. If none is passed, then the data type will be inferred by torch.

metadata : MetricMetadata

A typed dictionary containing at least an ‘id’ field of type str.
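The interaction between output_key and output_transform described above can be summarized as a small decision procedure. The following is a minimal sketch of those documented rules in plain Python, not MAITE's actual implementation. The function resolve_output, its parameter names, and the choice of ValueError are all hypothetical. It also assumes that an explicit output_key wraps the raw result even when that result is already a dictionary, which the text above does not spell out.

```python
from typing import Any, Callable, Optional


def resolve_output(
    raw: Any,
    metric_name: str,
    output_key: Optional[str] = None,
    output_transform: Optional[Callable[[Any], dict]] = None,
) -> dict:
    """Hypothetical helper mirroring the documented rules; not MAITE source."""
    # Documented constraint: at most one of the two options may be given.
    # (The exception type is an assumption.)
    if output_key is not None and output_transform is not None:
        raise ValueError("Provide at most one of output_key and output_transform")
    # A transform fully controls the shape of the returned dictionary.
    if output_transform is not None:
        return output_transform(raw)
    # Assumption: an explicit key always becomes the outermost key.
    if output_key is not None:
        return {output_key: raw}
    # A result that is already a dict with str keys is passed through.
    if isinstance(raw, dict) and all(isinstance(k, str) for k in raw):
        return raw
    # Otherwise fall back to the metric's name as the outermost key.
    return {metric_name: raw}


default = resolve_output(0.6667, "MulticlassAccuracy")
keyed = resolve_output(0.6667, "MulticlassAccuracy", output_key="acc")
transformed = resolve_output(
    0.6667, "MulticlassAccuracy", output_transform=lambda x: {"top1": x}
)
print(default)      # {'MulticlassAccuracy': 0.6667}
print(keyed)        # {'acc': 0.6667}
print(transformed)  # {'top1': 0.6667}
```

The first call reproduces the shape of the result in the Examples section, where neither option is passed and the metric's name becomes the key.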