metrics.torchmetrics.TMDetectionMetric

class maite.interop.metrics.torchmetrics.TMDetectionMetric(metric, output_key=None, output_transform=None, device=None, dtype=None, metadata=None)

MAITE-compliant Wrapper for Torchmetrics Detection Metrics.

TMDetectionMetric wraps a torchmetrics.detection metric so that it adheres to MAITE’s maite.protocols.object_detection.Metric protocol.

CompleteIntersectionOverUnion, DistanceIntersectionOverUnion, GeneralizedIntersectionOverUnion, IntersectionOverUnion, and MeanAveragePrecision with iou_type='bbox' are supported.

Examples

>>> from typing import Sequence
>>> import torch
>>> import torchmetrics.detection
>>> from maite.protocols import object_detection as od
>>> from maite.interop.metrics.torchmetrics import TMDetectionMetric
>>> from dataclasses import dataclass
>>> @dataclass
... class ObjectDetectionTarget_Impl:
...     boxes: torch.Tensor
...     labels: torch.Tensor
...     scores: torch.Tensor
>>> detection_metric = torchmetrics.detection.MeanAveragePrecision(iou_type="bbox")
>>> output_transform = lambda x: x["map_50"]
>>> wrapped_detect_metric: od.Metric = TMDetectionMetric(
...     detection_metric, output_key="MAP", output_transform=output_transform
... )
>>> preds: Sequence[od.TargetType] = [
...     ObjectDetectionTarget_Impl(
...         boxes=torch.Tensor([[258.0, 41.0, 606.0, 285.0]]),
...         labels=torch.Tensor([0]),
...         scores=torch.Tensor([0.536]),
...     )
... ]
>>> targets: Sequence[od.TargetType] = [
...     ObjectDetectionTarget_Impl(
...         boxes=torch.Tensor([[214.0, 41.0, 562.0, 285.0]]),
...         labels=torch.Tensor([0]),
...         scores=torch.Tensor([1.0]),
...     )
... ]
>>> wrapped_detect_metric.update(preds, targets)
>>> results = wrapped_detect_metric.compute()
>>> print(results)
{'MAP': tensor(1.)}
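The reported value of 1.0 can be checked by hand. The following plain-Python sketch computes the intersection over union (IoU) of the predicted and target boxes from the example. It assumes the xyxy box format, which is the default for torchmetrics' MeanAveragePrecision. The helper name box_iou is hypothetical and is not part of MAITE or torchmetrics.

```python
def box_iou(a, b):
    """IoU of two axis-aligned boxes given as (x_min, y_min, x_max, y_max)."""
    # Width and height of the overlap region, clamped at zero when the boxes are disjoint.
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)


# Boxes copied from the example above.
pred_box = (258.0, 41.0, 606.0, 285.0)
target_box = (214.0, 41.0, 562.0, 285.0)

iou = box_iou(pred_box, target_box)
print(round(iou, 4))  # 0.7755
```

Because the IoU of about 0.78 exceeds the 0.5 threshold and the labels match, the single prediction counts as a true positive, which is consistent with a map_50 value of 1.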
Parameters:
metric : torchmetrics.Metric

Metric being wrapped. The metric must be a torchmetrics.detection metric.

output_key : str, optional

Outermost key of the dictionary returned by TMDetectionMetric.compute(). If no key is provided and the result of torchmetrics.Metric.compute() (or of output_transform, if provided) is not a dictionary with string keys, the key defaults to the name of the wrapped torchmetrics.Metric.
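The key-selection rule described above can be sketched in plain Python. This illustrates the documented behavior only and is not MAITE's implementation. The function name wrap_result and its parameters are hypothetical, and the branch that returns a string-keyed dictionary unchanged is an assumption inferred from the description. Plain floats stand in for torch.Tensor values.

```python
from typing import Any, Optional


def wrap_result(result: Any, metric_name: str, output_key: Optional[str] = None) -> dict:
    # An explicit output_key always becomes the outermost key.
    if output_key is not None:
        return {output_key: result}
    # Assumption: a dictionary keyed by strings is returned as it is.
    if isinstance(result, dict) and all(isinstance(k, str) for k in result):
        return result
    # Otherwise fall back to the name of the wrapped metric.
    return {metric_name: result}


with_key = wrap_result(0.83, "MeanAveragePrecision", output_key="MAP")
without_key = wrap_result(0.83, "MeanAveragePrecision")
already_dict = wrap_result({"map_50": 0.83}, "MeanAveragePrecision")
print(with_key, without_key, already_dict)
```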

output_transform : Callable[dict[str, torch.Tensor]], optional

Function that takes the output of torchmetrics.Metric.compute() as input and returns a modified version of it.
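As an illustration of such a function, the sketch below keeps only a chosen subset of entries from a result dictionary. The name select_keys is hypothetical. The keys map, map_50, and map_75 are among those reported by torchmetrics' MeanAveragePrecision. Plain floats stand in for the torch.Tensor values a real metric would return, so that the snippet runs without torch installed.

```python
from typing import Any, Mapping, Sequence


def select_keys(
    results: Mapping[str, Any], keys: Sequence[str] = ("map", "map_50")
) -> dict:
    """Return a new dictionary holding only the requested entries that are present."""
    return {k: results[k] for k in keys if k in results}


# Stand-in for the dictionary returned by torchmetrics.Metric.compute().
raw = {"map": 0.61, "map_50": 0.83, "map_75": 0.57}
filtered = select_keys(raw)
print(filtered)  # {'map': 0.61, 'map_50': 0.83}
```

A function like this could be passed as output_transform=select_keys when constructing TMDetectionMetric. Since it returns a dictionary with string keys, no output_key would be needed under the rule described for that parameter.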

device : Any, optional

Torch device type on which a torch.Tensor is or will be allocated. If none is passed, then the device will be inferred by torch.

dtype : torch.dtype, optional

Torch data type. If none is passed, then data type will be inferred by torch.

metadata : MetricMetadata, optional

A typed dictionary containing at least an ‘id’ field of type str.

Raises:
ValueError

If an unsupported metric is given.