maite.protocols.object_detection.Model
class maite.protocols.object_detection.Model(*args, **kwargs)
A model protocol for the object detection ML subproblem.
Implementers must provide a __call__ method that operates on a batch of model inputs (as Sequence[ArrayLike]) and returns a batch of model targets (as Sequence[ObjectDetectionTarget]).
Examples
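Conformance to a protocol like this one is structural: a class satisfies it by having a matching __call__ and a metadata attribute, with no inheritance required. As a first, dependency-free illustration of that idea, the sketch below defines a hypothetical, simplified stand-in protocol (DetectorLike) with typing.Protocol, plus a hypothetical Target dataclass. Neither is MAITE's actual definition, and plain lists replace arrays so the block runs without NumPy.

```python
from dataclasses import dataclass
from typing import Any, Protocol, Sequence, runtime_checkable


@dataclass
class Target:
    # Hypothetical stand-in for an ObjectDetectionTarget.
    boxes: list
    labels: list
    scores: list


@runtime_checkable
class DetectorLike(Protocol):
    # Hypothetical, simplified mirror of the Model protocol's shape.
    metadata: dict

    def __call__(self, input_batch: Sequence[Any]) -> Sequence[Target]: ...


class EmptyDetector:
    # No base class: conformance comes only from the attributes below.
    metadata = {"id": "EmptyDetector"}

    def __call__(self, input_batch: Sequence[Any]) -> Sequence[Target]:
        # Return one (empty) target per input so batch sizes line up.
        return [Target(boxes=[], labels=[], scores=[]) for _ in input_batch]


# The annotation lets a static type checker verify conformance.
detector: DetectorLike = EmptyDetector()
outputs = detector([object(), object()])
```

The runtime isinstance check that runtime_checkable enables only tests for the presence of attributes; checking argument and return types remains the job of a static type checker.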
We create a simple MAITE-compliant object detection model. Note that it is a dummy model.
We define an object detection target dataclass to serve as the output of the object detection model.
>>> from dataclasses import dataclass
>>> import numpy as np
>>> @dataclass
... class MyObjectDetectionTarget:
...     boxes: np.ndarray
...     labels: np.ndarray
...     scores: np.ndarray
...
Specify parameters that will be used to create a dummy dataset.
>>> N_DATAPOINTS = 2  # datapoints in dataset
>>> N_CLASSES = 5  # possible classes that can be detected
>>> C = 3  # number of color channels
>>> H = 32  # img height
>>> W = 32  # img width
Now create a batch of data to serve as input to MAITE's object detection model.
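One way to build such a batch is sketched below, assuming it is simply a list of N_DATAPOINTS random float arrays in (C, H, W) layout, the input shape documented for __call__. The name simple_batch is the one the example uses later; the random contents and the fixed seed are assumptions made for illustration.

```python
import numpy as np

# Parameters repeated from the example above so this block runs on its own.
N_DATAPOINTS = 2  # datapoints in dataset
C = 3  # number of color channels
H = 32  # img height
W = 32  # img width

# Assumption: each input is a random float image in (C, H, W) layout.
rng = np.random.default_rng(seed=0)
simple_batch = [rng.random((C, H, W)) for _ in range(N_DATAPOINTS)]
```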
We define a simple object detection model. Note that no actual object detection takes place here: the __call__ method just returns hard-coded boxes with random labels and scores, wrapped in MyObjectDetectionTarget instances.
>>> from typing import Sequence
>>> import maite.protocols.object_detection as od
>>> from maite.protocols import ModelMetadata
>>> class ObjectDetectionDummyModel:
...     metadata: ModelMetadata = {"id": "ObjectDetectionDummyModel"}
...
...     def __call__(
...         self, batch: od.InputBatchType
...     ) -> Sequence[MyObjectDetectionTarget]:
...         # For simplicity, we don't run a real detector here; we return
...         # hard-coded boxes with random labels and scores instead.
...         DETECTIONS_PER_IMG = (
...             2  # number of bounding-box detections per image
...         )
...         all_boxes = np.array(
...             [[1, 3, 5, 9], [2, 5, 8, 12], [4, 10, 8, 20], [3, 5, 6, 15]]
...         )  # box pool; image i uses rows i to i + DETECTIONS_PER_IMG - 1
...         all_predictions = list()
...         for datum_idx in range(N_DATAPOINTS):
...             boxes = all_boxes[datum_idx : datum_idx + DETECTIONS_PER_IMG]
...             labels = np.random.randint(N_CLASSES, size=DETECTIONS_PER_IMG)
...             scores = np.random.rand(DETECTIONS_PER_IMG)
...             predictions = MyObjectDetectionTarget(boxes, labels, scores)
...             all_predictions.append(predictions)
...         return all_predictions
...
We can instantiate this class and type-hint it as a MAITE object detection model. The type hint lets a static type checker (such as mypy or pyright) verify protocol compliance.
>>> od_dummy_model: od.Model = ObjectDetectionDummyModel()
>>> od_dummy_model.metadata
{'id': 'ObjectDetectionDummyModel'}
>>> predictions = od_dummy_model(simple_batch)
>>> predictions
[MyObjectDetectionTarget(boxes=array([[ 1,  3,  5,  9],
       [ 2,  5,  8, 12]]), labels=array([..., ...]), scores=array([..., ...])), MyObjectDetectionTarget(boxes=array([[ 2,  5,  8, 12],
       [ 4, 10,  8, 20]]), labels=array([..., ...]), scores=array([..., ...]))]
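Each prediction above pairs every box row with one label and one score. A small hypothetical helper, check_target (not part of MAITE), can verify that pairing for a target object. It assumes boxes are stored as an (N, 4) array, as in the dummy model's output, and that labels and scores each hold N entries.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class MyObjectDetectionTarget:
    boxes: np.ndarray
    labels: np.ndarray
    scores: np.ndarray


def check_target(target) -> int:
    """Hypothetical helper: return the detection count N, or raise ValueError."""
    boxes = np.asarray(target.boxes)
    # Assumption: one row of 4 coordinates per detection.
    if boxes.ndim != 2 or boxes.shape[1] != 4:
        raise ValueError(f"expected boxes of shape (N, 4), got {boxes.shape}")
    n = boxes.shape[0]
    # Each detection needs exactly one label and one score.
    if len(target.labels) != n or len(target.scores) != n:
        raise ValueError("boxes, labels and scores disagree on N")
    return n


good = MyObjectDetectionTarget(
    boxes=np.array([[1, 3, 5, 9], [2, 5, 8, 12]]),
    labels=np.array([0, 4]),
    scores=np.array([0.9, 0.2]),
)
```

Calling check_target on every element of predictions is a cheap sanity check when writing a real model wrapper.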
Attributes
metadata : ModelMetadata
A typed dictionary containing at least an ‘id’ field of type str.
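A typed dictionary with a required id key can be sketched with typing.TypedDict. ModelMetadataSketch below is a hypothetical stand-in that only illustrates the shape of the data; MAITE's own ModelMetadata may declare further optional keys.

```python
from typing import TypedDict


class ModelMetadataSketch(TypedDict):
    # Hypothetical mirror: the one required field is the string 'id'.
    id: str


metadata: ModelMetadataSketch = {"id": "ObjectDetectionDummyModel"}
# At runtime a TypedDict value is a plain dict; only a static type
# checker enforces the declared keys and value types.
```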
Methods
__call__(input_batch: Sequence[ArrayLike]) -> Sequence[ObjectDetectionTarget]
Make a model prediction for the inputs in the input batch. Each element of the input batch is expected to have shape (C, H, W).
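A model implementation may want to confirm that every element of the input batch really is a 3-D (C, H, W) array before running inference. The helper below, check_chw_batch, is a hypothetical NumPy sketch and not part of MAITE; it converts each element with np.asarray so any array-like input is accepted.

```python
from typing import List, Sequence, Tuple

import numpy as np


def check_chw_batch(batch: Sequence) -> List[Tuple[int, int, int]]:
    """Hypothetical helper: return each element's (C, H, W) shape or raise ValueError."""
    shapes = []
    for i, element in enumerate(batch):
        arr = np.asarray(element)
        # Reject anything that is not exactly three-dimensional.
        if arr.ndim != 3:
            raise ValueError(f"element {i} has shape {arr.shape}; expected (C, H, W)")
        c, h, w = arr.shape
        shapes.append((c, h, w))
    return shapes


batch = [np.zeros((3, 32, 32)), np.zeros((3, 16, 24))]
print(check_chw_batch(batch))  # [(3, 32, 32), (3, 16, 24)]
```

Note that a 3-D check cannot tell (C, H, W) from (H, W, C); if that distinction matters, an implementation could additionally compare the first axis against an expected channel count.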