metrics.torchmetrics.TMClassificationMetric
- class maite.interop.metrics.torchmetrics.TMClassificationMetric(metric, output_key=None, output_transform=None, device=None, dtype=None, metadata=None)
MAITE-compliant wrapper for TorchMetrics classification metrics.

TMClassificationMetric is a wrapper around torchmetrics.classification metrics that adheres to MAITE’s image_classification.Metric protocol.

Notes
Only Multiclass metrics are currently supported. Please refer to the torchmetrics documentation for more information: https://lightning.ai/docs/torchmetrics/stable/.

- Supported metrics:
(Unsupported metrics either require preds and targets to have different shapes from each other or require additional parameters in their update() function signature.)

Examples
>>> import torch
>>> import torchmetrics.classification
>>> from typing_extensions import Sequence
>>> from maite.protocols import image_classification as ic, MetricMetadata
>>> from maite.interop.metrics.torchmetrics import TMClassificationMetric
>>>
>>> preds: Sequence[ic.TargetType] = [torch.tensor([0.1, 0.8, 0.1]), torch.tensor([0.6, 0.2, 0.2]), torch.tensor([0.4, 0.3, 0.3])]
>>> target: Sequence[ic.TargetType] = [torch.tensor([0, 1, 0]), torch.tensor([1, 0, 0]), torch.tensor([0, 0, 1])]
>>>
>>> # Create native TorchMetrics metric
>>> classification_metric = torchmetrics.classification.MulticlassAccuracy(num_classes=3)
>>>
>>> # Add additional field to base MetricMetadata
>>> class MyMetricMetadata(MetricMetadata):
...     num_classes: int
>>> metadata: MyMetricMetadata = {"id": "Multiclass Accuracy", "num_classes": 3}
>>>
>>> # Wrap metric and apply to sample data
>>> wrapped_classification_metric: ic.Metric = TMClassificationMetric(
...     classification_metric, metadata=metadata
... )
>>> wrapped_classification_metric.update(preds, target)
>>> result = wrapped_classification_metric.compute()
>>> result
{'MulticlassAccuracy': tensor(0.6667)}
>>> print(f"{result['MulticlassAccuracy'].item():0.3f}")  # consistent formatting for doctest
0.667
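To see where the 0.667 in the doctest above comes from, the arithmetic can be reproduced by hand. The sketch below is plain Python with no TorchMetrics dependency. It assumes (this page does not confirm it) that the one-hot targets are reduced to class indices by argmax and that MulticlassAccuracy uses macro averaging, i.e. the mean of per-class recall. The helpers `argmax` and `macro_accuracy` are hypothetical names used only for illustration.

```python
from typing import List, Sequence


def argmax(values: Sequence[float]) -> int:
    """Index of the largest entry (the first one wins on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def macro_accuracy(pred_idx: Sequence[int], target_idx: Sequence[int], num_classes: int) -> float:
    """Mean over classes of (correct predictions for class c) / (samples whose target is c).

    Classes with no target samples are skipped; this is a simplification of
    how a real library may treat them.
    """
    per_class: List[float] = []
    for c in range(num_classes):
        # Predictions for the samples whose true class is c.
        support = [p for p, t in zip(pred_idx, target_idx) if t == c]
        if support:
            per_class.append(sum(p == c for p in support) / len(support))
    return sum(per_class) / len(per_class)


# Same data as the doctest, as plain lists.
preds = [[0.1, 0.8, 0.1], [0.6, 0.2, 0.2], [0.4, 0.3, 0.3]]
target = [[0, 1, 0], [1, 0, 0], [0, 0, 1]]

# Assumption: both predictions and one-hot targets are reduced to class indices by argmax.
pred_idx = [argmax(p) for p in preds]     # [1, 0, 0]
target_idx = [argmax(t) for t in target]  # [1, 0, 2]

print(f"{macro_accuracy(pred_idx, target_idx, num_classes=3):0.3f}")  # 0.667
```

Classes 0 and 1 are each predicted correctly once, class 2 is missed, so the macro mean is 2/3. In this particular example micro averaging (2 of 3 samples correct) happens to give the same value.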
- Parameters:
- metric : torchmetrics.Metric
  Metric being wrapped. The metric must be a torchmetrics.classification metric.
- output_key : str, optional
  Outermost key returned from calling TMClassificationMetric.compute(). If neither output_key nor output_transform is provided and the metric result is not already a dictionary with str keys, then the name of the torchmetrics.Metric will be used as the default outermost key. Note: At most one of output_key and output_transform may be provided.
- output_transform : Callable[[torch.Tensor], dict[str, Any]], optional
  Function that takes the output of torchmetrics.Metric.compute() as input and returns a modified version of it to be returned by TMClassificationMetric.compute(). Note: At most one of output_key and output_transform may be provided.
- device : Any, optional
  Torch device type on which a torch.Tensor is or will be allocated. If None is passed, the device will be inferred by torch.
- dtype : torch.dtype, optional
  Torch data type. If None is passed, the data type will be inferred by torch.
- metadata : MetricMetadata
  A typed dictionary containing at least an ‘id’ field of type str.
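The rules for output_key and output_transform described in the parameter list can be summarized as a small decision function. This is a hypothetical, stdlib-only sketch of the documented behavior, not MAITE's actual implementation. The function name `resolve_output`, its `metric_name` argument, and the choice to raise ValueError when both options are given are assumptions made for illustration; plain floats stand in for torch.Tensor results.

```python
from typing import Any, Callable, Dict, Optional


def resolve_output(
    result: Any,
    metric_name: str,
    output_key: Optional[str] = None,
    output_transform: Optional[Callable[[Any], Dict[str, Any]]] = None,
) -> Dict[str, Any]:
    """Hypothetical sketch: shape a raw metric result into the dict returned by compute()."""
    # At most one of output_key and output_transform may be provided
    # (raising ValueError here is an assumption, not documented behavior).
    if output_key is not None and output_transform is not None:
        raise ValueError("Provide at most one of output_key and output_transform.")
    if output_transform is not None:
        # The transform fully controls the returned value.
        return output_transform(result)
    if output_key is not None:
        # The caller's key becomes the outermost key.
        return {output_key: result}
    if isinstance(result, dict) and all(isinstance(k, str) for k in result):
        # Already a dict with str keys: pass it through unchanged.
        return result
    # Fall back to the wrapped metric's name as the outermost key.
    return {metric_name: result}


print(resolve_output(0.6667, "MulticlassAccuracy"))
# {'MulticlassAccuracy': 0.6667}
print(resolve_output(0.6667, "MulticlassAccuracy", output_key="accuracy"))
# {'accuracy': 0.6667}
print(resolve_output(0.6667, "MulticlassAccuracy", output_transform=lambda r: {"accuracy": round(r, 2)}))
# {'accuracy': 0.67}
```

Ordering the checks this way makes the two explicit options take precedence over the default naming, which matches the documented rule that the metric name is used only when neither option is given and the result is not already a str-keyed dict.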