metrics.torchmetrics.TMClassificationMetric
- class maite.interop.metrics.torchmetrics.TMClassificationMetric(metric, output_key=None, output_transform=None, device=None, dtype=None, metadata=None)
MAITE-compliant Wrapper for TorchMetrics Classification Metrics.
TMClassificationMetric is a wrapper around torchmetrics.classification metrics that adheres to MAITE’s image_classification.Metric protocol.
Notes
Only Multiclass metrics are currently supported. Please refer to the torchmetrics documentation for more information: https://lightning.ai/docs/torchmetrics/stable/.
- Supported metrics:
(Unsupported metrics either require preds and targets to have different shapes or require additional parameters in their update() function signature.)
Examples
>>> import torch
>>> import torchmetrics.classification
>>> from typing_extensions import Sequence
>>> from maite.protocols import image_classification as ic, MetricMetadata
>>> from maite.interop.metrics.torchmetrics import TMClassificationMetric
>>>
>>> preds: Sequence[ic.TargetType] = [
...     torch.tensor([0.1, 0.8, 0.1]),
...     torch.tensor([0.6, 0.2, 0.2]),
...     torch.tensor([0.4, 0.3, 0.3]),
... ]
>>> target: Sequence[ic.TargetType] = [
...     torch.tensor([0, 1, 0]),
...     torch.tensor([1, 0, 0]),
...     torch.tensor([0, 0, 1]),
... ]
>>> metadatas: Sequence[ic.DatumMetadataType] = [{"id": 1}, {"id": 2}, {"id": 3}]
>>> # Create native TorchMetrics metric
>>> classification_metric = torchmetrics.classification.MulticlassAccuracy(
...     num_classes=3
... )
>>>
>>> # Add additional field to base MetricMetadata
>>> class MyMetricMetadata(MetricMetadata):
...     num_classes: int
>>> metadata: MyMetricMetadata = {"id": "Multiclass Accuracy", "num_classes": 3}
>>>
>>> # Wrap metric and apply to sample data
>>> wrapped_classification_metric: ic.Metric = TMClassificationMetric(
...     classification_metric, metadata=metadata
... )
>>> wrapped_classification_metric.update(preds, target, metadatas)
>>> result = wrapped_classification_metric.compute()
>>> result
{'MulticlassAccuracy': tensor(0.6667)}
>>> print(
...     f"{result['MulticlassAccuracy'].item():0.3f}"
... )  # consistent formatting for doctest
0.667
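The 0.667 printed above can be checked by hand. The pure-Python sketch below assumes that each probability vector and each one-hot target is reduced to a class index by arg-max before the comparison (an assumption; this page does not say how the wrapper converts its inputs) and that MulticlassAccuracy uses its default average="macro". For this particular data the micro average happens to give the same value.

```python
# Hand check of the doctest value; pure Python, no torch or MAITE needed.
# Assumption: predictions and one-hot targets are compared via their arg-max.
preds = [
    [0.1, 0.8, 0.1],
    [0.6, 0.2, 0.2],
    [0.4, 0.3, 0.3],
]
targets = [
    [0, 1, 0],
    [1, 0, 0],
    [0, 0, 1],
]
num_classes = 3


def argmax(values):
    """Index of the largest entry (the first one if there is a tie)."""
    return max(range(len(values)), key=lambda i: values[i])


pred_labels = [argmax(p) for p in preds]    # [1, 0, 0]
true_labels = [argmax(t) for t in targets]  # [1, 0, 2]

# Micro average: fraction of all samples that are classified correctly.
micro = sum(p == t for p, t in zip(pred_labels, true_labels)) / len(true_labels)

# Macro average: accuracy within each class, then the mean over classes.
# Every class occurs in these targets; empty classes are skipped defensively.
per_class = []
for c in range(num_classes):
    idx = [i for i, t in enumerate(true_labels) if t == c]
    if idx:
        per_class.append(sum(pred_labels[i] == c for i in idx) / len(idx))
macro = sum(per_class) / len(per_class)

print(f"{micro:0.3f} {macro:0.3f}")  # 0.667 0.667
```

Two of the three samples are correct (micro), and the per-class accuracies are 1, 1 and 0 (macro), so both averages equal 2/3.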
- Parameters:
- metric : torchmetrics.Metric
  Metric being wrapped. The metric must be a torchmetrics.classification metric.
- output_key : str, optional
  Outermost key returned from calling TMClassificationMetric.compute(). If neither output_key nor output_transform is provided and the metric result is not already a dictionary with str keys, then the name of the torchmetrics.Metric is used as the default outermost key. Note: at most one of output_key and output_transform may be provided.
- output_transform : Callable[[torch.Tensor], dict[str, Any]], optional
  Function that takes the output of torchmetrics.Metric.compute() as input and returns a modified version of it to be returned by TMClassificationMetric.compute(). Note: at most one of output_key and output_transform may be provided.
- device : Any, optional
  Torch device on which a torch.Tensor is or will be allocated. If None is passed, the device is inferred by torch.
- dtype : torch.dtype, optional
  Torch data type. If None is passed, the data type is inferred by torch.
- metadata : MetricMetadata
  A typed dictionary containing at least an ‘id’ field of type str.
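The precedence among output_key, output_transform and the default key described in the parameter list can be summarized in a short pure-Python sketch. resolve_output is a hypothetical helper written for illustration, not part of MAITE. It assumes that output_key nests the raw result under the given key, and that the default key is the wrapped metric's class name, as the 'MulticlassAccuracy' key in the example suggests.

```python
from typing import Any, Callable, Dict, Optional


def resolve_output(
    result: Any,
    metric_name: str,
    output_key: Optional[str] = None,
    output_transform: Optional[Callable[[Any], Dict[str, Any]]] = None,
) -> Dict[str, Any]:
    """Hypothetical restatement of the documented output_key/output_transform rules."""
    if output_key is not None and output_transform is not None:
        raise ValueError("At most one of output_key and output_transform may be provided.")
    if output_transform is not None:
        # The transform receives the raw compute() result and builds the returned dict.
        return output_transform(result)
    if output_key is not None:
        # Assumption: the raw result is nested under the caller-chosen outermost key.
        return {output_key: result}
    if isinstance(result, dict) and all(isinstance(k, str) for k in result):
        # Already a dictionary with str keys: returned unchanged.
        return result
    # Fallback: key the result by the wrapped metric's name.
    return {metric_name: result}


raw = 0.6667  # stand-in for the value returned by torchmetrics.Metric.compute()
print(resolve_output(raw, "MulticlassAccuracy"))
# {'MulticlassAccuracy': 0.6667}
print(resolve_output(raw, "MulticlassAccuracy", output_key="accuracy"))
# {'accuracy': 0.6667}
print(resolve_output(raw, "MulticlassAccuracy", output_transform=lambda r: {"metrics/accuracy": r}))
# {'metrics/accuracy': 0.6667}
```

Rejecting the combination of both arguments up front mirrors the note that at most one of them may be provided.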