maite.protocols.image_classification.Dataset

class maite.protocols.image_classification.Dataset(*args, **kwargs)

A dataset protocol for the image classification ML subproblem that provides datum-level data access.

Implementers must provide index lookup (via a __getitem__(ind: int) method) and support len() (via a __len__() method). Elements looked up this way correspond to individual examples (as opposed to batches).

Indexing into an image_classification dataset returns (and iterating over it yields) a tuple of types ArrayLike, ArrayLike, and DatumMetadataType. These correspond to the model input type, model target type, and datum-level metadata type, respectively.
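If a class defines no __iter__, Python falls back to calling __getitem__ with 0, 1, 2, ... until an IndexError is raised. The two required methods are therefore enough to support both indexing and iteration, provided that __getitem__ raises IndexError past the end. The minimal sketch below uses plain lists in place of ArrayLike values. TinyDataset is a hypothetical name and is not part of maite.

```python
# Hypothetical minimal dataset: plain lists stand in for ArrayLike inputs/targets.
class TinyDataset:
    def __init__(self, inputs, targets, metadata):
        self._inputs = inputs
        self._targets = targets
        self._metadata = metadata

    def __len__(self) -> int:
        return len(self._inputs)

    def __getitem__(self, ind: int):
        # List indexing raises IndexError past the end, which is the signal
        # Python's fallback iteration protocol uses to stop.
        return self._inputs[ind], self._targets[ind], self._metadata[ind]


ds = TinyDataset([[0.0, 1.0], [2.0, 3.0]], [[1, 0], [0, 1]], [{"id": 0}, {"id": 1}])
first = ds[0]          # (input, target, metadata) for datum 0
all_items = list(ds)   # iteration calls ds[0], ds[1], ... until IndexError
```

A dataset whose __getitem__ raises a different error at the end would need its own __iter__ or an explicit bounds check. A dict keyed by index, for example, raises KeyError, and that exception propagates instead of ending the loop.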

Examples

We create a dummy set of data and use it to create a class that implements this lightweight dataset protocol:

>>> import numpy as np
>>> from typing import Any, TypedDict
>>> from maite.protocols import ArrayLike, DatasetMetadata, DatumMetadata

Assume that we have 5 classes, 10 datapoints, and 10 one-hot target labels, and that each datapoint's metadata should contain an integer ‘id’ field:

>>> N_CLASSES: int = 5
>>> N_DATUM: int = 10
>>> images: list[np.ndarray] = [np.random.rand(3, 32, 16) for _ in range(N_DATUM)]
>>> targets: np.ndarray = np.eye(N_CLASSES)[np.random.choice(N_CLASSES, N_DATUM)]
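The targets line above builds one-hot rows by indexing an identity matrix with integer labels. The standalone sketch below shows the resulting shape and confirms that argmax recovers the original labels. For reproducibility it substitutes a seeded np.random.default_rng generator for the unseeded np.random.choice call.

```python
import numpy as np

N_CLASSES = 5
N_DATUM = 10

# Seeded generator so the sketch is reproducible (the example above is unseeded)
rng = np.random.default_rng(0)
labels = rng.choice(N_CLASSES, N_DATUM)  # integer class indices, shape (10,)

# Row i of the identity matrix is the one-hot vector for class i, so
# fancy-indexing with labels yields one one-hot row per datum
targets = np.eye(N_CLASSES)[labels]      # shape (10, 5)

# argmax along the class axis undoes the encoding
recovered = targets.argmax(axis=1)
```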

We can type our datum metadata as a maite.protocols DatumMetadata, or we can define our own TypedDict that extends it with additional fields:

>>> class MyDatumMetadata(DatumMetadata):
...     hour_of_day: float
...
>>> datum_metadata = [MyDatumMetadata(id=i, hour_of_day=np.random.rand() * 24) for i in range(N_DATUM)]
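A TypedDict subclass inherits its parent's keys, but only static checkers enforce the field types. At runtime, an instance is an ordinary dict. The sketch below illustrates this with DatumMetadataStandIn, a hypothetical TypedDict that has a single int ‘id’ field. It stands in for maite's DatumMetadata and does not reproduce that class's exact definition.

```python
from typing import TypedDict


# Hypothetical stand-in for maite.protocols.DatumMetadata (not its real definition)
class DatumMetadataStandIn(TypedDict):
    id: int


# Extending the TypedDict adds a field; 'id' is inherited from the parent
class MyDatumMetadata(DatumMetadataStandIn):
    hour_of_day: float


md = MyDatumMetadata(id=3, hour_of_day=13.5)

# At runtime no validation happens: a static checker would reject this,
# but Python builds the dict anyway
bad = MyDatumMetadata(id="not-an-int", hour_of_day=13.5)
```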

Constructing a compliant dataset only requires a simple wrapper that fetches individual datapoints, where a datapoint is a single (image, target, metadata) 3-tuple.

>>> class ImageDataset:
...     def __init__(self,
...                  dataset_name: str,
...                  index2label: dict[int,str],
...                  images: list[np.ndarray],
...                  targets: np.ndarray,
...                  datum_metadata: list[MyDatumMetadata]):
...         self.images = images
...         self.targets = targets
...         self.metadata = DatasetMetadata({'id': dataset_name, 'index2label': index2label})
...         self._datum_metadata = datum_metadata
...     def __len__(self) -> int:
...         return len(self.images)
...     def __getitem__(self, ind: int) -> tuple[np.ndarray, np.ndarray, MyDatumMetadata]:
...         return self.images[ind], self.targets[ind], self._datum_metadata[ind]

We can instantiate this class and type hint it as an image_classification.Dataset. The type hint allows a static type checker such as mypy or pyright to verify protocol compliance.

>>> from maite.protocols import image_classification as ic
>>> dataset: ic.Dataset = ImageDataset('a_dataset',
...                                    {i: f"class_name_{i}" for i in range(N_CLASSES)},
...                                    images, targets, datum_metadata)
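Compliance is checked structurally, because the class never inherits from the protocol. A static checker compares full method signatures. By contrast, a protocol marked with typing.runtime_checkable only supports isinstance checks for whether the methods exist. The sketch below illustrates this difference using a hypothetical DatasetLike protocol. DatasetLike is not maite's definition, and the sketch makes no assumption about whether maite's protocols support isinstance checks.

```python
from typing import Any, Protocol, runtime_checkable


# Hypothetical stand-in protocol; NOT maite's image_classification.Dataset
@runtime_checkable
class DatasetLike(Protocol):
    def __len__(self) -> int: ...
    def __getitem__(self, ind: int) -> tuple[Any, Any, dict]: ...


class GoodDataset:
    def __len__(self) -> int:
        return 1

    def __getitem__(self, ind: int) -> tuple[list, list, dict]:
        return [0.0], [1], {"id": ind}


class MissingLen:
    # No __len__, so it does not structurally match DatasetLike
    def __getitem__(self, ind: int) -> tuple[list, list, dict]:
        return [0.0], [1], {"id": ind}


# isinstance only checks that the protocol's methods are present,
# not that their signatures or return types match
good_ok = isinstance(GoodDataset(), DatasetLike)
missing_ok = isinstance(MissingLen(), DatasetLike)
```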

Note that when writing a Dataset implementer, return types may be narrower than the return types promised by the protocol (np.ndarray is a subtype of ArrayLike), but the argument types must be at least as general as the argument types promised by the protocol.
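The rule above is ordinary subtyping variance: return types may be narrowed (covariance), while argument types may only be widened (contravariance). The hypothetical IndexedSource protocol below makes this concrete. The comments state what a static checker would accept or reject. At runtime, all three classes run without error.

```python
from __future__ import annotations

from typing import Protocol, Sequence


# Hypothetical protocol: index with an int, get back a sequence of floats
class IndexedSource(Protocol):
    def __getitem__(self, ind: int) -> Sequence[float]: ...


class NarrowerReturn:
    # Accepted: list[float] is a subtype of Sequence[float] (returns are covariant)
    def __getitem__(self, ind: int) -> list[float]:
        return [float(ind)]


class WiderArgument:
    # Accepted: int | slice covers every int a caller may pass (arguments are contravariant)
    def __getitem__(self, ind: int | slice) -> list[float]:
        data = [0.0, 1.0, 2.0]
        if isinstance(ind, slice):
            return data[ind]
        return [data[ind]]


class NarrowerArgument:
    # Rejected as an IndexedSource by a static checker: callers may pass
    # any int, but this method only promises to handle bool
    def __getitem__(self, ind: bool) -> list[float]:
        return [float(ind)]


def first_value(src: IndexedSource, ind: int) -> float:
    return src[ind][0]


a: IndexedSource = NarrowerReturn()
b: IndexedSource = WiderArgument()
# c: IndexedSource = NarrowerArgument()  # a type checker would flag this assignment
```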

Attributes

metadata : DatasetMetadata

A typed dictionary containing at least an ‘id’ field of type str.
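The only field the protocol guarantees is ‘id’. The example above also passes ‘index2label’. The sketch below models that extra key as optional, which is an assumption. It shows one way to express "at least an ‘id’ field" with a standard TypedDict, using a hypothetical DatasetMetadataSketch class rather than maite's actual DatasetMetadata.

```python
from typing import TypedDict


# Hypothetical stand-in, not maite's actual DatasetMetadata definition.
# The base class holds the required key...
class _RequiredDatasetMetadata(TypedDict):
    id: str


# ...and total=False marks the extra key from the example as optional (assumption)
class DatasetMetadataSketch(_RequiredDatasetMetadata, total=False):
    index2label: dict[int, str]


minimal = DatasetMetadataSketch(id="a_dataset")
full = DatasetMetadataSketch(id="a_dataset", index2label={0: "class_name_0"})
```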

Methods

__getitem__(ind: int) -> tuple[ArrayLike, ArrayLike, DatumMetadataType]

Provide map-style access to dataset elements. Returned tuple elements correspond to model input type, model target type, and datum-specific metadata type, respectively.

__len__() -> int

Return the number of data elements in the dataset.
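Together, the two methods are enough to walk an entire dataset. The sketch below uses a hypothetical check_dataset helper, which is not part of maite, to sanity-check every datum at runtime. It assumes that array-like values expose __array__, as NumPy arrays do, and that datum metadata behaves like a dict with an ‘id’ key.

```python
import numpy as np


def check_dataset(dataset) -> int:
    """Hypothetical helper: walk a dataset via __len__/__getitem__ and
    sanity-check each datum. Returns the number of datums checked."""
    for i in range(len(dataset)):
        datum = dataset[i]
        # Each element must be a 3-tuple: (input, target, datum metadata)
        if not (isinstance(datum, tuple) and len(datum) == 3):
            raise TypeError(f"datum {i} is not a 3-tuple")
        model_input, target, metadata = datum
        # Assumption: array-like values expose __array__, as NumPy arrays do
        for name, value in (("input", model_input), ("target", target)):
            if not hasattr(value, "__array__"):
                raise TypeError(f"datum {i} {name} is not array-like")
        if "id" not in metadata:
            raise KeyError(f"datum {i} metadata has no 'id' field")
    return len(dataset)


class _ListDataset:
    # Minimal hypothetical dataset used only for this usage example
    def __init__(self, items):
        self._items = items

    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, ind: int):
        return self._items[ind]


good = _ListDataset([(np.zeros((3, 4, 4)), np.eye(2)[0], {"id": 0})])
bad = _ListDataset([(np.zeros((3, 4, 4)), np.eye(2)[1], {"hour_of_day": 5.0})])

good_count = check_dataset(good)
try:
    check_dataset(bad)
    bad_error = None
except KeyError as exc:
    bad_error = exc
```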