metrics.torchmetrics.TMDetectionMetric
- class maite.interop.metrics.torchmetrics.TMDetectionMetric(metric, output_key=None, output_transform=None, device=None, dtype=None, metadata=None)
MAITE-compliant wrapper for torchmetrics detection metrics.
`TMDetectionMetric` wraps a `torchmetrics.detection` metric so that it adheres to MAITE's `maite.protocols.object_detection.Metric` protocol. The supported metrics are `CompleteIntersectionOverUnion`, `DistanceIntersectionOverUnion`, `GeneralizedIntersectionOverUnion`, `IntersectionOverUnion`, and `MeanAveragePrecision` (with `iou_type='bbox'` only).
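As a rough illustration of what adapting between the two interfaces involves: the MAITE object-detection targets in the example expose `boxes`, `labels`, and `scores` attributes, whereas the `update()` method of `torchmetrics.detection` metrics such as `MeanAveragePrecision` takes lists of dictionaries keyed by `"boxes"`, `"labels"`, and (for predictions) `"scores"`. The sketch below assumes the wrapper performs a conversion of roughly this shape before delegating to the wrapped metric. `to_torchmetrics_format` is a hypothetical helper, not part of MAITE, and plain lists stand in for tensors.

```python
from types import SimpleNamespace
from typing import Any, Sequence


def to_torchmetrics_format(items: Sequence[Any], include_scores: bool) -> list[dict[str, Any]]:
    """Hypothetical helper: turn attribute-style targets into torchmetrics-style dicts."""
    converted = []
    for item in items:
        # torchmetrics.detection metrics look up boxes and labels by dictionary key.
        entry = {"boxes": item.boxes, "labels": item.labels}
        if include_scores:
            # Confidence scores are only needed on the prediction side.
            entry["scores"] = item.scores
        converted.append(entry)
    return converted


# Plain lists stand in for torch.Tensor values in this sketch.
pred = SimpleNamespace(boxes=[[258.0, 41.0, 606.0, 285.0]], labels=[0], scores=[0.536])
target = SimpleNamespace(boxes=[[214.0, 41.0, 562.0, 285.0]], labels=[0], scores=[1.0])
preds_tm = to_torchmetrics_format([pred], include_scores=True)
targets_tm = to_torchmetrics_format([target], include_scores=False)
```

Keeping attribute access on the MAITE side and key access on the torchmetrics side is the point of the adapter: neither library needs to know the other's container type.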
Examples

```python
>>> from typing import Sequence
>>> import torch
>>> import torchmetrics.detection
>>> from maite.protocols import object_detection as od
>>> from maite.interop.metrics.torchmetrics import TMDetectionMetric
>>> from dataclasses import dataclass
>>> @dataclass
... class ObjectDetectionTarget_Impl:
...     boxes: torch.Tensor
...     labels: torch.Tensor
...     scores: torch.Tensor
>>> # Wrap a bbox mAP metric and keep only the mAP at IoU 0.5 under the key "MAP"
>>> detection_metric = torchmetrics.detection.MeanAveragePrecision(iou_type="bbox")
>>> output_transform = lambda x: x["map_50"]
>>> wrapped_detect_metric: od.Metric = TMDetectionMetric(
...     detection_metric, output_key="MAP", output_transform=output_transform
... )
>>> preds: Sequence[od.TargetType] = [
...     ObjectDetectionTarget_Impl(
...         boxes=torch.Tensor([[258.0, 41.0, 606.0, 285.0]]),
...         labels=torch.Tensor([0]),
...         scores=torch.Tensor([0.536]),
...     )
... ]
>>> targets: Sequence[od.TargetType] = [
...     ObjectDetectionTarget_Impl(
...         boxes=torch.Tensor([[214.0, 41.0, 562.0, 285.0]]),
...         labels=torch.Tensor([0]),
...         scores=torch.Tensor([1.0]),
...     )
... ]
>>> wrapped_detect_metric.update(preds, targets)
>>> results = wrapped_detect_metric.compute()
>>> print(results)
{'MAP': tensor(1.)}
```
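The `{'MAP': tensor(1.)}` result in the example can be checked by hand. `map_50` is the average precision at an IoU threshold of 0.5; with a single prediction and a single target of the same label, AP is 1 whenever their IoU reaches that threshold. The sketch below computes the IoU of the two example boxes, assuming they are in `(x1, y1, x2, y2)` order, which is the default `box_format="xyxy"` of `MeanAveragePrecision`. `box_iou_xyxy` is an illustrative helper, not a torchmetrics function.

```python
from typing import Sequence


def box_iou_xyxy(a: Sequence[float], b: Sequence[float]) -> float:
    """IoU of two axis-aligned boxes given as (x1, y1, x2, y2)."""
    # Corners of the intersection rectangle.
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    # Clamp at zero so disjoint boxes have zero intersection.
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


pred_box = [258.0, 41.0, 606.0, 285.0]
target_box = [214.0, 41.0, 562.0, 285.0]
iou = box_iou_xyxy(pred_box, target_box)
print(round(iou, 4))  # 0.7755, comfortably above the 0.5 threshold behind map_50
```

Both boxes are 348 by 244 and overlap over a 304-pixel-wide strip, so the IoU is 304 / (2 × 348 − 304) = 38/49.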
- Parameters:
  - metric : torchmetrics.Metric
    Metric being wrapped. It must be one of the supported `torchmetrics.detection` metrics.
  - output_key : str, optional
    Outermost key of the dictionary returned by `TMDetectionMetric.compute()`. If no key is provided and the result of `torchmetrics.Metric.compute()` (or of `output_transform`, if provided) is not a dictionary with string keys, the key defaults to the name of the wrapped `torchmetrics.Metric`.
  - output_transform : Callable[dict[str, torch.Tensor]], optional
    Function that takes the output of `torchmetrics.Metric.compute()` as input and returns a modified version of it, for example `lambda x: x["map_50"]` to keep a single entry.
  - device : Any, optional
    Torch device on which tensors are or will be allocated. If none is passed, the device is inferred by torch.
  - dtype : torch.dtype, optional
    Torch data type. If none is passed, the data type is inferred by torch.
  - metadata : MetricMetadata, optional
    A typed dictionary containing at least an `'id'` field of type `str`.
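The interaction between `output_key` and `output_transform` described above can be summarized as a small function. This is a hypothetical reconstruction from the parameter descriptions, not MAITE's implementation: it assumes the transform is applied first, that an explicit `output_key` always becomes the single outermost key, and that, without one, a dictionary with string keys is returned unchanged while anything else is stored under the metric's name. Plain floats stand in for tensors, and `resolve_output` is an illustrative name.

```python
from typing import Any, Callable, Optional


def resolve_output(
    raw: Any,
    metric_name: str,
    output_key: Optional[str] = None,
    output_transform: Optional[Callable[[Any], Any]] = None,
) -> dict[str, Any]:
    """Hypothetical reconstruction of how compute() results might be keyed."""
    # Apply the user-supplied transform first, if any.
    result = output_transform(raw) if output_transform is not None else raw
    # An explicit output_key always becomes the single outermost key.
    if output_key is not None:
        return {output_key: result}
    # A dictionary with string keys is passed through unchanged.
    if isinstance(result, dict) and all(isinstance(k, str) for k in result):
        return result
    # Anything else is stored under the wrapped metric's name.
    return {metric_name: result}


raw = {"map": 0.9, "map_50": 1.0}  # floats stand in for tensors
print(resolve_output(raw, "MeanAveragePrecision", "MAP", lambda x: x["map_50"]))
# {'MAP': 1.0}
```

With `output_key="MAP"` and a transform that selects `"map_50"`, this reproduces the shape of the `{'MAP': tensor(1.)}` output in the example.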
- Raises:
  - ValueError
    If the wrapped metric is not one of the supported `torchmetrics.detection` metrics listed above.
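The supported-metric restriction and the `ValueError` could be enforced by a check along these lines. This is a hypothetical sketch keyed on class names; `SUPPORTED_METRICS` and `check_supported` are illustrative names, and how MAITE actually validates the wrapped metric (for example, via `isinstance` checks or by inspecting the metric's IoU type) may differ.

```python
from typing import Optional

# Class names of the torchmetrics.detection metrics listed as supported.
SUPPORTED_METRICS = frozenset(
    {
        "CompleteIntersectionOverUnion",
        "DistanceIntersectionOverUnion",
        "GeneralizedIntersectionOverUnion",
        "IntersectionOverUnion",
        "MeanAveragePrecision",
    }
)


def check_supported(metric_name: str, iou_type: Optional[str] = None) -> None:
    """Raise ValueError if the metric is outside the supported set (hypothetical)."""
    if metric_name not in SUPPORTED_METRICS:
        raise ValueError(f"Unsupported metric: {metric_name}")
    # MeanAveragePrecision is only supported for bounding-box evaluation.
    if metric_name == "MeanAveragePrecision" and iou_type != "bbox":
        raise ValueError(
            f"MeanAveragePrecision requires iou_type='bbox', got {iou_type!r}"
        )


check_supported("MeanAveragePrecision", iou_type="bbox")  # passes silently
```

Failing fast at construction time, rather than on the first `update()` call, keeps the error close to the line that chose the metric.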