image_classification.Dataset
- class maite.protocols.image_classification.Dataset(*args, **kwargs)
A dataset protocol for the image classification AI task, providing datum-level data access.
Implementers must provide index lookup (via the __getitem__(ind: int) method) and support len (via the __len__() method). Data elements looked up this way correspond to individual examples (as opposed to batches).
Indexing into or iterating over an image_classification dataset returns a tuple of types ArrayLike, ArrayLike, and DatumMetadataType. These correspond to the model input type, model target type, and datum-level metadata, respectively.
- Attributes:
- metadata: DatasetMetadata
A typed dictionary containing at least an ‘id’ field of type str.
Methods
__getitem__(ind: int) -> tuple[ArrayLike, ArrayLike, DatumMetadataType]
Provide map-style access to dataset elements. Returned tuple elements correspond to model input type, model target type, and datum-specific metadata type, respectively.
__len__() -> int
Return the number of data elements in the dataset.
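As a rough illustration of what these two methods enable, the sketch below defines a local stand-in protocol and a toy implementer using only the standard library. MapStyleDataset, TinyDataset, and collect_ids are hypothetical names that are not part of MAITE, and nested lists stand in for array-like data. It also shows that plain iteration can fall back on __getitem__, provided out-of-range indices raise IndexError:

```python
from typing import Any, Protocol


class MapStyleDataset(Protocol):
    """Hypothetical local stand-in for the protocol's two required methods."""

    def __getitem__(self, ind: int) -> tuple[Any, Any, dict[str, Any]]: ...

    def __len__(self) -> int: ...


class TinyDataset:
    """Toy implementer; nested lists stand in for array-like inputs and targets."""

    def __init__(self) -> None:
        self._inputs = [[0.0, 0.1], [0.2, 0.3], [0.4, 0.5]]
        self._targets = [[1, 0], [0, 1], [1, 0]]

    def __len__(self) -> int:
        return len(self._inputs)

    def __getitem__(self, ind: int) -> tuple[list[float], list[int], dict[str, Any]]:
        # List indexing raises IndexError past the end, which is what lets
        # Python's fallback iteration (via __getitem__) terminate.
        return self._inputs[ind], self._targets[ind], {"id": ind}


def collect_ids(dataset: MapStyleDataset) -> list[Any]:
    """Visit every datum by index and gather the 'id' from its metadata."""
    return [dataset[i][2]["id"] for i in range(len(dataset))]


tiny = TinyDataset()
ids_by_index = collect_ids(tiny)
# No __iter__ is defined, so the for-loop machinery calls __getitem__(0), (1), ...
ids_by_iteration = [metadata["id"] for _, _, metadata in tiny]
```

Both traversals visit the same three datums in the same order.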
Examples
We create a dummy set of data and use it to create a class that implements this lightweight dataset protocol:
>>> import numpy as np
>>> from typing import Any, TypedDict
>>> from maite.protocols import ArrayLike, DatasetMetadata, DatumMetadata
Assume we have 5 classes, 10 datapoints, and 10 target labels, and that we want to simply have an integer ‘id’ field in each datapoint’s metadata:
>>> N_CLASSES: int = 5
>>> N_DATUM: int = 10
>>> images: list[np.ndarray] = [np.random.rand(3, 32, 16) for _ in range(N_DATUM)]
>>> targets: np.ndarray = np.eye(N_CLASSES)[np.random.choice(N_CLASSES, N_DATUM)]
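The targets expression above picks rows of an identity matrix by class index, which yields one one-hot row per datum. A dependency-free sketch of the same idea follows; the one_hot helper and the fixed seed are assumptions made for illustration only:

```python
import random

N_CLASSES = 5
N_DATUM = 10

# Seeded generator so the sketch is repeatable (an illustrative choice).
rng = random.Random(0)
labels = [rng.randrange(N_CLASSES) for _ in range(N_DATUM)]


def one_hot(label: int, n_classes: int) -> list[float]:
    """Return row `label` of an n_classes x n_classes identity matrix."""
    return [1.0 if k == label else 0.0 for k in range(n_classes)]


# One row per datum, each with a single 1.0 at the class index.
targets = [one_hot(label, N_CLASSES) for label in labels]

# The class index can be recovered from the position of the 1.0.
recovered = [row.index(1.0) for row in targets]
```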
We can type our datum metadata as a maite.protocols DatumMetadata, or define our own TypedDict with additional fields:
>>> class MyDatumMetadata(DatumMetadata):
...     hour_of_day: float
...
>>> datum_metadata = [MyDatumMetadata(id=i, hour_of_day=np.random.rand() * 24) for i in range(N_DATUM)]
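To show the TypedDict extension pattern in isolation, here is a small sketch that uses only the standard library. BaseDatumMetadata is a hypothetical stand-in assumed to mirror a metadata type with a required 'id' key; it is not MAITE's actual definition, and TimedDatumMetadata is likewise an illustrative name:

```python
from typing import TypedDict, Union


class BaseDatumMetadata(TypedDict):
    """Hypothetical stand-in for a datum metadata base type with a required 'id'."""

    id: Union[int, str]


class TimedDatumMetadata(BaseDatumMetadata):
    """Adds one extra required key on top of the inherited 'id'."""

    hour_of_day: float


# Calling a TypedDict class builds an ordinary dict; the declared keys are
# checked by static type checkers, not enforced at runtime.
records = [TimedDatumMetadata(id=i, hour_of_day=6.0 * i) for i in range(3)]
```

Because each record is a plain dict at runtime, it can be passed anywhere a dictionary is accepted, while a static checker still sees the extra hour_of_day field.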
Constructing a compliant dataset just involves a simple wrapper that fetches individual datapoints, where a datapoint is a single (image, target, metadata) 3-tuple:
>>> class ImageDataset:
...     def __init__(self,
...                  dataset_name: str,
...                  index2label: dict[int, str],
...                  images: list[np.ndarray],
...                  targets: np.ndarray,
...                  datum_metadata: list[MyDatumMetadata]):
...         self.images = images
...         self.targets = targets
...         self.metadata = DatasetMetadata({'id': dataset_name, 'index2label': index2label})
...         self._datum_metadata = datum_metadata
...     def __len__(self) -> int:
...         # Count the images stored on this instance.
...         return len(self.images)
...     def __getitem__(self, ind: int) -> tuple[np.ndarray, np.ndarray, MyDatumMetadata]:
...         return self.images[ind], self.targets[ind], self._datum_metadata[ind]
We can instantiate this class and typehint it as an image_classification.Dataset. By using typehinting, we permit a static typechecker to verify protocol compliance.
>>> from maite.protocols import image_classification as ic
>>> dataset: ic.Dataset = ImageDataset('a_dataset',
...                                    {i: f"class_name_{i}" for i in range(N_CLASSES)},
...                                    images, targets, datum_metadata)
Note that when writing a Dataset implementer, return types may be narrower than the return types promised by the protocol (np.ndarray is a subtype of ArrayLike), but the argument types must be at least as general as the argument types promised by the protocol.
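The variance rule in the note above can be sketched with a small local protocol. RowSource, ListRows, and first_row are hypothetical names introduced for illustration and are not part of MAITE; the example only assumes standard typing.Protocol behavior:

```python
from typing import Protocol, Sequence, SupportsIndex


class RowSource(Protocol):
    """Hypothetical protocol: index with an int, get back a Sequence of floats."""

    def __getitem__(self, ind: int) -> Sequence[float]: ...


class ListRows:
    """Implementer with a wider argument type and a narrower return type."""

    def __init__(self, rows: list[list[float]]) -> None:
        self._rows = rows

    # SupportsIndex is more general than int, and list[float] is more
    # specific than Sequence[float]; both directions are permitted.
    def __getitem__(self, ind: SupportsIndex) -> list[float]:
        return self._rows[ind]


def first_row(source: RowSource) -> Sequence[float]:
    """Consumer that relies only on what the protocol promises."""
    return source[0]


rows = ListRows([[1.0, 2.0], [3.0, 4.0]])
head = first_row(rows)
```

An implementer whose __getitem__ accepted only a narrower argument type than int would work for some callers at runtime, but a static type checker would be expected to reject it as a RowSource.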