image_classification.FieldwiseDataset
- class maite.protocols.image_classification.FieldwiseDataset(*args, **kwargs)
A specialization (i.e., a subprotocol) of the Dataset protocol that specifies additional accessor methods for getting the input, target, and metadata of a datum individually.
Methods
__getitem__(ind: int) -> tuple[InputType, TargetType, DatumMetadataType]
Provide map-style access to dataset elements. Returned tuple elements correspond to model input type, model target type, and datum-specific metadata type, respectively.
__len__() -> int
Return the number of data elements in the dataset.
get_input(index: int, /) -> InputType
Get input at the given index.
get_target(index: int, /) -> TargetType
Get target at the given index.
get_metadata(index: int, /) -> DatumMetadataType
Get metadata at the given index.
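The method listing above can be read as a structural interface. The following is a minimal sketch of that idea using only the standard library: FieldwiseLike and TinyDataset are hypothetical names introduced for illustration, not part of maite, and the sketch types every field as Any rather than reproducing maite's actual definitions. The final check that the accessors agree with __getitem__ is an assumption about intended behaviour, not something the listing states.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class FieldwiseLike(Protocol):
    """Hypothetical stand-in mirroring the five methods listed above."""

    def __getitem__(self, ind: int) -> tuple[Any, Any, Any]: ...
    def __len__(self) -> int: ...
    def get_input(self, index: int, /) -> Any: ...
    def get_target(self, index: int, /) -> Any: ...
    def get_metadata(self, index: int, /) -> Any: ...


class TinyDataset:
    """Toy dataset that uses strings as stand-ins for images."""

    def __init__(self) -> None:
        self._inputs = ["img0", "img1", "img2"]
        self._targets = [0, 1, 0]
        self._metadata = [{"id": i} for i in range(3)]

    def __getitem__(self, ind: int) -> tuple[str, int, dict[str, int]]:
        return self._inputs[ind], self._targets[ind], self._metadata[ind]

    def __len__(self) -> int:
        return len(self._inputs)

    def get_input(self, index: int, /) -> str:
        return self._inputs[index]

    def get_target(self, index: int, /) -> int:
        return self._targets[index]

    def get_metadata(self, index: int, /) -> dict[str, int]:
        return self._metadata[index]


tiny = TinyDataset()

# isinstance() on a runtime-checkable protocol only checks that the
# methods exist; it does not validate signatures or return types.
is_structural_match = isinstance(tiny, FieldwiseLike)

# Assumed contract: each accessor returns the matching tuple element.
fields_agree = all(
    tiny[i] == (tiny.get_input(i), tiny.get_target(i), tiny.get_metadata(i))
    for i in range(len(tiny))
)
```

Because the runtime check only looks for method names, a static type checker remains the more thorough way to verify compliance.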
Examples
We create a dummy set of data and use it to create a class that implements the dataset protocol:
>>> import numpy as np
>>> from typing import Any
>>> from typing_extensions import TypedDict
>>> from maite.protocols import ArrayLike, DatasetMetadata, DatumMetadata
Assume we have 5 classes and 10 datapoints, each with its own one-hot target label, and that each datapoint’s metadata carries an integer ‘id’ field:
>>> N_CLASSES: int = 5
>>> N_DATUM: int = 10
>>> images: list[np.ndarray] = [np.random.rand(3, 32, 16) for _ in range(N_DATUM)]
>>> targets: np.ndarray = np.eye(N_CLASSES)[np.random.choice(N_CLASSES, N_DATUM)]
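The targets expression above builds one-hot rows by indexing an identity matrix with random class indices. A small sketch of that step in isolation follows; the seeded generator and the names rng, class_indices, one_hot_targets, and recovered are assumptions added here so the result is reproducible, and are not part of the original example.

```python
import numpy as np

N_CLASSES = 5
N_DATUM = 10

# Seeded generator (an assumption, for reproducibility only).
rng = np.random.default_rng(0)

# One integer class index per datapoint, each in [0, N_CLASSES).
class_indices = rng.integers(0, N_CLASSES, size=N_DATUM)

# Row k of the identity matrix is the one-hot vector for class k, so
# indexing with an array of class indices stacks the matching rows.
one_hot_targets = np.eye(N_CLASSES)[class_indices]

# argmax along each row recovers the original class index.
recovered = one_hot_targets.argmax(axis=1)
```

The result has shape (N_DATUM, N_CLASSES), with exactly one 1.0 per row.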
We can type our datum metadata as a maite.protocols DatumMetadata, or define our own TypedDict with additional fields. Here we extend DatumMetadata with an hour_of_day field:
>>> class MyDatumMetadata(DatumMetadata):
...     hour_of_day: float
>>> datum_metadata = [
...     MyDatumMetadata(id=i, hour_of_day=np.random.rand() * 24)
...     for i in range(N_DATUM)
... ]
Constructing a compliant dataset just involves a simple wrapper that fetches individual datapoints, where a datapoint is a single (image, target, metadata) 3-tuple, and that also exposes each of those three fields through its own accessor.
>>> class ImageDataset:
...     def __init__(
...         self,
...         dataset_name: str,
...         index2label: dict[int, str],
...         images: list[np.ndarray],
...         targets: np.ndarray,
...         datum_metadata: list[MyDatumMetadata],
...     ):
...         self.images = images
...         self.targets = targets
...         self.metadata = DatasetMetadata(
...             {"id": dataset_name, "index2label": index2label}
...         )
...         self._datum_metadata = datum_metadata
...
...     def get_input(self, index, /) -> np.ndarray:
...         return self.images[index]
...
...     def get_target(self, index, /) -> np.ndarray:
...         return self.targets[index]
...
...     def get_metadata(self, index, /) -> MyDatumMetadata:
...         return self._datum_metadata[index]
...
...     def __len__(self) -> int:
...         # Use the instance's own images, not the module-level list.
...         return len(self.images)
...
...     def __getitem__(
...         self, ind: int
...     ) -> tuple[np.ndarray, np.ndarray, MyDatumMetadata]:
...         return self.images[ind], self.targets[ind], self._datum_metadata[ind]
We can instantiate this class and type hint it as an image_classification.FieldwiseDataset. The type hint lets a static type checker verify protocol compliance.
>>> from maite.protocols import image_classification as ic
>>> dataset: ic.FieldwiseDataset = ImageDataset(
...     "a_dataset",
...     {i: f"class_name_{i}" for i in range(N_CLASSES)},
...     images,
...     targets,
...     datum_metadata,
... )
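One plausible use of the fieldwise accessors is reading a single field across the whole dataset without fetching the others. The sketch below illustrates this with a toy dataset that counts input fetches; CountingDataset and target_histogram are hypothetical names introduced here, integer targets stand in for the one-hot arrays used above, and nothing in this block depends on maite.

```python
from collections import Counter


class CountingDataset:
    """Hypothetical toy dataset that counts how often inputs are fetched."""

    def __init__(self, targets: list[int]) -> None:
        self._targets = targets
        self.input_fetches = 0

    def __len__(self) -> int:
        return len(self._targets)

    def get_input(self, index: int, /) -> str:
        # Stand-in for a potentially expensive image load.
        self.input_fetches += 1
        return f"image_{index}"

    def get_target(self, index: int, /) -> int:
        return self._targets[index]

    def get_metadata(self, index: int, /) -> dict[str, int]:
        return {"id": index}

    def __getitem__(self, ind: int) -> tuple[str, int, dict[str, int]]:
        return self.get_input(ind), self.get_target(ind), self.get_metadata(ind)


def target_histogram(dataset) -> Counter:
    """Count targets using only get_target and __len__."""
    return Counter(dataset.get_target(i) for i in range(len(dataset)))


counting = CountingDataset([0, 2, 2, 1, 2])
histogram = target_histogram(counting)
fetches_after_histogram = counting.input_fetches  # no inputs were touched

_ = counting[0]  # map-style access fetches all three fields
fetches_after_getitem = counting.input_fetches
```

Whether this saves real work depends on how a given dataset implements its accessors; the counter here only makes the difference visible.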